lightrag-comments/lightrag/chinese.py
many2many e339bfab07 [ADD] Chunk text by Chinese punctuation.
The implementation of chunking_by_token_size can be changed to: return chunking_by_chinese_splitter(content, overlap_token_size, max_token_size, tiktoken_model)
2024-11-17 12:31:50 +08:00


import re
from typing import Any, List, Optional

from langchain_text_splitters import RecursiveCharacterTextSplitter
# CHUNK_SIZE is the size of each piece (chunk) that a large text is split into for
# processing. Chunking keeps memory usage manageable and avoids loading too much data
# at once; set it according to the context window (in tokens) of the LLM API being called.
# CHUNK_OVERLAP is the number of tokens shared between adjacent chunks. Overlap prevents
# information from being lost at chunk boundaries, which helps accuracy and coherence in
# downstream tasks such as text classification and entity extraction.
DEFAULT_CHUNK_SIZE = 2500  # tokens
DEFAULT_CHUNK_OVERLAP = 300  # tokens
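# Note: these module-level defaults are not referenced anywhere else in this file;
# chunking_by_chinese_splitter below uses its own defaults (1024 / 128), so callers
# who want 2500 / 300 must pass them explicitly.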
def chunking_by_chinese_splitter(
    content: str,
    overlap_token_size=128,  # number of overlapping tokens between adjacent chunks
    max_token_size=1024,  # maximum number of tokens per chunk
    tiktoken_model="gpt-4o",  # tokenizer model name (not used in the body below)
):
    """Split long text into overlapping chunks along Chinese punctuation.

    Args:
        content (str): Raw text to split.
        overlap_token_size (int, optional): Overlap between adjacent chunks. Defaults to 128.
        max_token_size (int, optional): Maximum size of each chunk. Defaults to 1024.
        tiktoken_model (str, optional): Tokenizer model name. Defaults to "gpt-4o".

    Returns:
        list[dict]: One dict per chunk with the fields:
            - tokens: size of the chunk (here a character count, since the splitter's
              default length function is len rather than a tiktoken encoding)
            - content: text of the chunk
            - chunk_order_index: position of the chunk in the original text
    """
    text_splitter = ChineseRecursiveTextSplitter(
        keep_separator=True,
        is_separator_regex=True,
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )
    # Split the text
    chunks = text_splitter.split_text(content)
    # Build the result list
    result = []
    for i, chunk in enumerate(chunks):
        result.append(
            {
                "tokens": len(chunk),
                "content": chunk,
                "chunk_order_index": i,
            }
        )
    return result
def _split_text_with_regex_from_end(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    if separator:
        if keep_separator:
            # Parentheses in the pattern keep the separators in the result,
            # so each one can be glued back onto the end of the preceding piece.
            _splits = re.split(f"({separator})", text)
            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
            if len(_splits) % 2 == 1:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
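# Illustration of the "from end" behaviour (not in the original source): with
# keep_separator=True, splitting "你好。世界。" on the separator "。" yields
# ["你好。", "世界。"], i.e. each separator stays attached to the end of the
# piece it terminates instead of starting the next piece.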
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or [
            r"\n\n",
            r"\n",
            r"。|！|？",
            r"\.\s|\!\s|\?\s",
            r"；|;\s",
            r"，|,\s",
        ]
        self._is_separator_regex = is_separator_regex
    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split the incoming text and return the processed chunks."""
        final_chunks = []
        # Pick the first separator that actually occurs in the text
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break
        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
        # Pieces that fit within the chunk size are merged; oversized pieces are
        # split again recursively with the remaining, finer-grained separators.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return [
            re.sub(r"\n{2,}", "\n", chunk.strip())
            for chunk in final_chunks
            if chunk.strip() != ""
        ]