many2many
e339bfab07
The implementation of chunking_by_token_size can be replaced with: return chunking_by_chinese_splitter(content, overlap_token_size, max_token_size, tiktoken_model) (a sketch of this wrapper appears after chunking_by_chinese_splitter below)
import re
from typing import Any, List, Optional

import tiktoken
from langchain_text_splitters import RecursiveCharacterTextSplitter

# CHUNK_SIZE is the size of each piece (chunk) when a large document is split
# up. Chunking keeps memory usage under control and avoids loading too much
# data at once; set it according to the context-window token limit of the
# LLM API you are calling.

# CHUNK_OVERLAP is the portion shared by adjacent chunks when splitting text.
# The overlap keeps important information from being lost at chunk boundaries,
# which helps accuracy and coherence in downstream tasks such as text
# classification and entity recognition.

DEFAULT_CHUNK_SIZE = 2500  # tokens
DEFAULT_CHUNK_OVERLAP = 300  # tokens
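
# Hedged illustration of the overlap idea: with DEFAULT_CHUNK_SIZE=2500 and
# DEFAULT_CHUNK_OVERLAP=300, roughly the last 300 tokens of chunk N reappear
# at the start of chunk N+1, so a sentence that straddles a boundary is still
# seen intact in at least one chunk.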


def chunking_by_chinese_splitter(
    content: str,
    overlap_token_size=128,  # number of tokens shared between adjacent chunks
    max_token_size=1024,  # maximum number of tokens per chunk
    tiktoken_model="gpt-4o",  # name of the tokenizer model to use
):
    """Split a long text into multiple overlapping chunks by token count.

    Args:
        content (str): The raw text to split.
        overlap_token_size (int, optional): Number of tokens shared between
            adjacent chunks. Defaults to 128.
        max_token_size (int, optional): Maximum number of tokens per chunk.
            Defaults to 1024.
        tiktoken_model (str, optional): Tokenizer model to use.
            Defaults to "gpt-4o".

    Returns:
        list[dict]: The resulting chunks, each containing the fields:
            - tokens: the actual number of tokens in the chunk
            - content: the text of the chunk
            - chunk_order_index: the position of the chunk in the sequence
    """
    # Measure length in tokens rather than characters, so that max_token_size
    # and overlap_token_size are honored in the units the docstring promises.
    encoder = tiktoken.encoding_for_model(tiktoken_model)
    text_splitter = ChineseRecursiveTextSplitter(
        keep_separator=True,
        is_separator_regex=True,
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
        length_function=lambda text: len(encoder.encode(text)),
    )
    # Split the text.
    chunks = text_splitter.split_text(content)
    # Build the result list.
    result = []
    for i, chunk in enumerate(chunks):
        result.append({
            "tokens": len(encoder.encode(chunk)),
            "content": chunk,
            "chunk_order_index": i,
        })

    return result
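

# A hedged sketch of the swap described at the top of this file: replace the
# body of the existing chunking_by_token_size with a call to the splitter
# above. The signature below simply mirrors that call; adapt it to whatever
# chunking_by_token_size looks like in your own codebase.
def chunking_by_token_size(
    content: str,
    overlap_token_size=128,
    max_token_size=1024,
    tiktoken_model="gpt-4o",
):
    # Delegate token-size chunking to the Chinese-aware splitter above.
    return chunking_by_chinese_splitter(
        content, overlap_token_size, max_token_size, tiktoken_model
    )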


def _split_text_with_regex_from_end(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the separators in the
            # result; each separator is then re-attached to the end of the
            # piece that precedes it.
            _splits = re.split(f"({separator})", text)
            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
            if len(_splits) % 2 == 1:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
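
# Hedged example of the behavior above: with keep_separator=True,
# _split_text_with_regex_from_end("你好。再见!", "。|!|?", True)
# returns ["你好。", "再见!"] -- each sentence keeps its trailing punctuation,
# which is why this helper is "from_end" rather than LangChain's default of
# prefixing the separator to the following piece.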


class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Try coarse separators first (paragraphs, lines), then Chinese and
        # English sentence-ending punctuation, then clause-level punctuation.
        self._separators = separators or [
            r"\n\n",
            r"\n",
            r"。|!|?",
            r"\.\s|\!\s|\?\s",
            r";|;\s",
            r",|,\s",
        ]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split the incoming text and return the processed chunks."""
        final_chunks = []
        # Find the coarsest separator that actually occurs in the text.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)

        # Merge pieces that fit within the chunk size; recurse with the finer
        # separators on any piece that is still too long.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return [
            re.sub(r"\n{2,}", "\n", chunk.strip())
            for chunk in final_chunks
            if chunk.strip() != ""
        ]
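

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original snippet: split a short
    # mixed-punctuation text and inspect the returned chunk dicts. Requires
    # the tiktoken and langchain-text-splitters packages to be installed.
    sample = "第一句话。第二句话!第三句话?\n\n新的段落,带一个逗号,然后结束。"
    for chunk in chunking_by_chinese_splitter(sample, overlap_token_size=0, max_token_size=16):
        print(chunk["chunk_order_index"], chunk["tokens"], repr(chunk["content"]))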