diff --git a/lightrag/chinese.py b/lightrag/chinese.py
new file mode 100644
index 0000000..e47171e
--- /dev/null
+++ b/lightrag/chinese.py
@@ -0,0 +1,142 @@
+import re
+from typing import Any, List, Optional
+
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+# CHUNK_SIZE is the size of each piece when a large document is split into
+# chunks. Chunking keeps memory usage manageable and avoids loading too much
+# data at once; set it according to the context window (in tokens) of the
+# target LLM API.
+# CHUNK_OVERLAP is the amount of text shared by adjacent chunks. The overlap
+# prevents information from being lost at chunk boundaries, which helps
+# accuracy and coherence in downstream tasks such as text classification and
+# entity extraction.
+DEFAULT_CHUNK_SIZE = 2500  # tokens
+DEFAULT_CHUNK_OVERLAP = 300  # tokens
+
+
+def chunking_by_chinese_splitter(
+    content: str,
+    overlap_token_size=128,  # overlap between adjacent chunks
+    max_token_size=1024,  # maximum size of each chunk
+    tiktoken_model="gpt-4o",  # tokenizer model name (currently unused)
+):
+    """Split a long text into multiple overlapping chunks.
+
+    Note: the splitter's default length function is ``len``, so sizes are
+    measured in characters rather than model tokens; ``tiktoken_model`` is
+    accepted for interface compatibility but not used yet.
+
+    Args:
+        content (str): The raw text to split.
+        overlap_token_size (int, optional): Overlap between adjacent chunks. Defaults to 128.
+        max_token_size (int, optional): Maximum size of each chunk. Defaults to 1024.
+        tiktoken_model (str, optional): Tokenizer model. Defaults to "gpt-4o".
+
+    Returns:
+        list[dict]: The resulting chunks, each with the fields:
+            - tokens: length of the chunk (characters, see note above)
+            - content: the chunk's text
+            - chunk_order_index: the chunk's position in the sequence
+    """
+    text_splitter = ChineseRecursiveTextSplitter(
+        keep_separator=True,
+        is_separator_regex=True,
+        chunk_size=max_token_size,
+        chunk_overlap=overlap_token_size,
+    )
+    # Split the text, then assemble the result records.
+    chunks = text_splitter.split_text(content)
+    result = []
+    for i, chunk in enumerate(chunks):
+        result.append({
+            "tokens": len(chunk),
+            "content": chunk,
+            "chunk_order_index": i,
+        })
+
+    return result
+
+
+def _split_text_with_regex_from_end(
+    text: str, separator: str, keep_separator: bool
+) -> List[str]:
+    if separator:
+        if keep_separator:
+            # Parentheses in the pattern keep the separators in the result;
+            # each separator is re-attached to the end of the piece before it.
+            _splits = re.split(f"({separator})", text)
+            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
+            if len(_splits) % 2 == 1:
+                splits += _splits[-1:]
+        else:
+            splits = re.split(separator, text)
+    else:
+        splits = list(text)
+    return [s for s in splits if s != ""]
+
+
+class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
+    def __init__(
+        self,
+        separators: Optional[List[str]] = None,
+        keep_separator: bool = True,
+        is_separator_regex: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(keep_separator=keep_separator, **kwargs)
+        # Separators ordered from coarsest (paragraph breaks) to finest
+        # (commas), covering both full-width and half-width punctuation.
+        self._separators = separators or [
+            r"\n\n",
+            r"\n",
+            r"。|!|?",
+            r"\.\s|\!\s|\?\s",
+            r";|;\s",
+            r",|,\s",
+        ]
+        self._is_separator_regex = is_separator_regex
+
+    def _split_text(self, text: str, separators: List[str]) -> List[str]:
+        """Split the given text and return the processed chunks."""
+        final_chunks = []
+        # Pick the first separator that actually occurs in the text.
+        separator = separators[-1]
+        new_separators = []
+        for i, _s in enumerate(separators):
+            _separator = _s if self._is_separator_regex else re.escape(_s)
+            if _s == "":
+                separator = _s
+                break
+            if re.search(_separator, text):
+                separator = _s
+                new_separators = separators[i + 1:]
+                break
+
+        _separator = separator if self._is_separator_regex else re.escape(separator)
+        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
+
+        # Merge short splits; recurse into any split that is still too long.
+        _good_splits = []
+        _separator = "" if self._keep_separator else separator
+        for s in splits:
+            if self._length_function(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, _separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                if not new_separators:
+                    final_chunks.append(s)
+                else:
+                    other_info = self._split_text(s, new_separators)
+                    final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, _separator)
+            final_chunks.extend(merged_text)
+        return [
+            re.sub(r"\n{2,}", "\n", chunk.strip())
+            for chunk in final_chunks
+            if chunk.strip() != ""
+        ]
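
For reference, a minimal usage sketch of the new splitter (not part of the patch; the sample text and sizes are illustrative, and the returned "tokens" values are character counts under the splitter's default length function):

    from lightrag.chinese import chunking_by_chinese_splitter

    text = "第一句。第二句!第三句?\n\n新的段落,带一个逗号。结尾。"
    # Small sizes chosen only to force multiple chunks on this short sample.
    chunks = chunking_by_chinese_splitter(text, overlap_token_size=4, max_token_size=20)
    for c in chunks:
        print(c["chunk_order_index"], c["tokens"], repr(c["content"]))

Because the separator list falls through from paragraph breaks to sentence-ending punctuation to commas, a chunk boundary lands on the coarsest separator that keeps each piece under the size limit.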