import re
from typing import Any, List, Optional

from langchain_text_splitters import RecursiveCharacterTextSplitter

# CHUNK_SIZE is the size of each piece (chunk) when a large document is split
# into smaller pieces. Chunking keeps memory usage manageable and avoids
# out-of-memory errors from loading too much data at once. It should be set
# according to the context token budget of the LLM API being called.

# CHUNK_OVERLAP is the portion shared between adjacent chunks when text is
# split. The overlap ensures important information is not lost at chunk
# boundaries, which helps model accuracy and coherence in downstream tasks
# such as text classification and entity recognition.

DEFAULT_CHUNK_SIZE = 2500  # tokens
DEFAULT_CHUNK_OVERLAP = 300  # tokens
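
# Illustration (sizes and string are assumptions, not taken from this module):
# with chunk_size=10 and chunk_overlap=3, a character-level splitter applied
# to "abcdefghijklmn" may emit "abcdefghij" followed by "hijklmn"; the last
# three characters of one chunk reappear at the start of the next, so text at
# the boundary is never lost.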


def chunking_by_chinese_splitter(
    content: str,
    overlap_token_size=128,  # number of overlapping tokens between chunks
    max_token_size=1024,  # maximum number of tokens per chunk
    tiktoken_model="gpt-4o",  # tokenizer model name (currently unused)
):
    """Split a long text into multiple overlapping chunks.

    Args:
        content (str): The raw text to split.
        overlap_token_size (int, optional): Overlap between adjacent chunks.
            Defaults to 128.
        max_token_size (int, optional): Maximum size of each chunk.
            Defaults to 1024.
        tiktoken_model (str, optional): Tokenizer model name. Defaults to
            "gpt-4o". Note: this parameter is accepted but never used; chunk
            length is measured by the splitter's length function instead.

    Returns:
        list[dict]: The resulting chunks, each with the following fields:
            - tokens: length of the chunk (measured with len(), i.e. in
              characters rather than tokenizer tokens)
            - content: the chunk text
            - chunk_order_index: the chunk's position in the sequence
    """
    text_splitter = ChineseRecursiveTextSplitter(
        keep_separator=True,
        is_separator_regex=True,
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )

    # Split the text.
    chunks = text_splitter.split_text(content)

    # Build the result list.
    result = []
    for i, chunk in enumerate(chunks):
        result.append({
            "tokens": len(chunk),
            "content": chunk,
            "chunk_order_index": i,
        })

    return result


def _split_text_with_regex_from_end(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    if separator:
        if keep_separator:
            # Parentheses in the pattern keep the separators in the result.
            _splits = re.split(f"({separator})", text)
            # Re-attach each separator to the segment that precedes it
            # (hence "from_end"), rather than the segment that follows it.
            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
            if len(_splits) % 2 == 1:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
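
# Worked example (a sketch; the sample string is an assumption, not from the
# original module):
#
#     _split_text_with_regex_from_end("你好。再见!", "。|!|?", True)
#
# re.split produces ['你好', '。', '再见', '!', '']; pairing even- and
# odd-indexed pieces re-attaches each sentence ender to its sentence, and the
# empty trailing element is filtered out, giving ['你好。', '再见!'].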


class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Separators are tried in order, from coarse to fine:
        self._separators = separators or [
            r"\n\n",  # paragraph breaks
            r"\n",  # line breaks
            r"。|!|?",  # Chinese sentence enders
            r"\.\s|\!\s|\?\s",  # English sentence enders
            r";|;\s",  # semicolons (full- and half-width)
            r",|,\s",  # commas (full- and half-width)
        ]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split the incoming text and return the processed chunks."""
        final_chunks = []
        # Pick the coarsest separator that actually occurs in the text; the
        # remaining finer-grained separators are kept for recursive splitting.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)

        # Merge splits that fit within the chunk size; recurse with the finer
        # separators on any split that is still too long.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        # Collapse runs of blank lines and drop chunks that are empty after
        # stripping whitespace.
        return [
            re.sub(r"\n{2,}", "\n", chunk.strip())
            for chunk in final_chunks
            if chunk.strip() != ""
        ]
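

# A minimal usage sketch (not part of the original module; the sample text and
# the small chunk sizes are assumptions chosen to make the overlap visible):
if __name__ == "__main__":
    sample = (
        "自然语言处理是人工智能的一个重要方向。它研究人与计算机之间如何用自然"
        "语言进行有效通信。\n\n"
        "Text splitting keeps each chunk within a model's context window, "
        "while overlap preserves context across chunk boundaries."
    )
    for chunk in chunking_by_chinese_splitter(
        sample, overlap_token_size=16, max_token_size=64
    ):
        print(chunk["chunk_order_index"], chunk["tokens"], repr(chunk["content"]))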