From 6969e4afc1fbe69ce9b22fd9e09b5fd36a6152a9 Mon Sep 17 00:00:00 2001
From: many2many <6168830@qq.com>
Date: Sat, 16 Nov 2024 11:40:45 +0800
Subject: [PATCH] Add comments using Cursor.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lightrag/base.py  |  61 +++++++++--
 lightrag/utils.py | 267 +++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 302 insertions(+), 26 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index bd47257..cac1e79 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -5,119 +5,162 @@ import numpy as np
 
 from .utils import EmbeddingFunc
 
+# Data structure for a text chunk: token count, content, full document ID, and chunk order index
 TextChunkSchema = TypedDict(
     "TextChunkSchema",
     {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
 )
 
+# Generic type variable
 T = TypeVar("T")
 
 
 @dataclass
 class QueryParam:
+    """Query parameter configuration.
+
+    Attributes:
+        mode: Query mode; one of 'local', 'global', 'hybrid', or 'naive'.
+        only_need_context: Whether only the retrieved context is needed.
+        response_type: Response type.
+        top_k: Number of top-k items to retrieve; entities in 'local' mode, relationships in 'global' mode.
+        max_token_for_text_unit: Maximum number of tokens for the original text chunks.
+        max_token_for_global_context: Maximum number of tokens for relationship descriptions.
+        max_token_for_local_context: Maximum number of tokens for entity descriptions.
+    """
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
     response_type: str = "Multiple Paragraphs"
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
-    # Number of tokens for the original chunks.
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
     max_token_for_local_context: int = 4000
 
 
 @dataclass
 class StorageNameSpace:
+    """Base class for a storage namespace.
+
+    Attributes:
+        namespace: The namespace.
+        global_config: Global configuration dictionary.
+    """
     namespace: str
     global_config: dict
 
     async def index_done_callback(self):
-        """commit the storage operations after indexing"""
+        """Callback invoked after indexing, used to commit storage operations"""
         pass
 
     async def query_done_callback(self):
-        """commit the storage operations after querying"""
+        """Callback invoked after querying, used to commit storage operations"""
         pass
 
 
 @dataclass
 class BaseVectorStorage(StorageNameSpace):
+    """Base class for vector storage.
+
+    Attributes:
+        embedding_func: Embedding function.
+        meta_fields: Set of metadata fields.
+    """
     embedding_func: EmbeddingFunc
     meta_fields: set = field(default_factory=set)
 
     async def query(self, query: str, top_k: int) -> list[dict]:
+        """Query interface"""
         raise NotImplementedError
 
     async def upsert(self, data: dict[str, dict]):
-        """Use 'content' field from value for embedding, use key as id.
-        If embedding_func is None, use 'embedding' field from value
+        """Upsert data.
+        Use the 'content' field of each value for embedding and the key as the ID.
+        If embedding_func is None, use the 'embedding' field of the value instead.
         """
         raise NotImplementedError
 
 
 @dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
+    """Base class for key-value storage"""
+
     async def all_keys(self) -> list[str]:
+        """Return all keys"""
         raise NotImplementedError
 
     async def get_by_id(self, id: str) -> Union[T, None]:
+        """Get a value by ID"""
         raise NotImplementedError
 
     async def get_by_ids(
         self, ids: list[str], fields: Union[set[str], None] = None
     ) -> list[Union[T, None]]:
+        """Get values in batch by a list of IDs"""
         raise NotImplementedError
 
     async def filter_keys(self, data: list[str]) -> set[str]:
-        """return un-exist keys"""
+        """Return the set of keys that do not exist"""
         raise NotImplementedError
 
     async def upsert(self, data: dict[str, T]):
+        """Update or insert data"""
         raise NotImplementedError
 
     async def drop(self):
+        """Drop the storage"""
         raise NotImplementedError
 
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
+    """Base class for graph storage"""
+
     async def has_node(self, node_id: str) -> bool:
+        """Check whether a node exists"""
         raise NotImplementedError
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        """Check whether an edge exists"""
         raise NotImplementedError
 
     async def node_degree(self, node_id: str) -> int:
+        """Get the degree of a node"""
         raise NotImplementedError
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        """Get the degree of an edge"""
         raise NotImplementedError
 
     async def get_node(self, node_id: str) -> Union[dict, None]:
+        """Get node information"""
         raise NotImplementedError
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
+        """Get edge information"""
         raise NotImplementedError
 
     async def get_node_edges(
         self, source_node_id: str
     ) -> Union[list[tuple[str, str]], None]:
+        """Get all edges of a node"""
         raise NotImplementedError
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        """Update or insert a node"""
         raise NotImplementedError
 
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
+        """Update or insert an edge"""
         raise NotImplementedError
 
     async def delete_node(self, node_id: str):
+        """Delete a node"""
         raise NotImplementedError
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+        """Node embedding (not used in lightrag)"""
         raise NotImplementedError("Node embedding is not used in lightrag.")
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 104c9fe..57013db 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -15,38 +15,65 @@ import xml.etree.ElementTree as ET
 import numpy as np
 import tiktoken
 
+# Global encoder variable
 ENCODER = None
 
+# Create a logger named "lightrag"
 logger = logging.getLogger("lightrag")
 
 
 def set_logger(log_file: str):
+    """Set up the logger.
+
+    Args:
+        log_file: Path to the log file.
+    """
+    # Set the log level to DEBUG
     logger.setLevel(logging.DEBUG)
 
+    # Create a file handler
     file_handler = logging.FileHandler(log_file)
     file_handler.setLevel(logging.DEBUG)
 
+    # Log format: time - name - level - message
     formatter = logging.Formatter(
         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     )
     file_handler.setFormatter(formatter)
 
+    # Add the handler only if the logger does not have one yet
     if not logger.handlers:
         logger.addHandler(file_handler)
 
 
 @dataclass
 class EmbeddingFunc:
+    """Wrapper class for an embedding function.
+
+    Attributes:
+        embedding_dim: Dimension of the embedding vectors.
+        max_token_size: Maximum number of tokens.
+        func: The actual embedding function.
+    """
     embedding_dim: int
     max_token_size: int
     func: callable
 
     async def __call__(self, *args, **kwargs) -> np.ndarray:
+        """Make instances callable; delegates directly to the wrapped embedding function"""
         return await self.func(*args, **kwargs)
 
 
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
-    """Locate the JSON string body from a string"""
+    """Locate the JSON string body within a string.
+
+    Args:
+        content: Input string.
+
+    Returns:
+        The JSON string found, or None.
+    """
+    # Use a regex to find content enclosed in {}; DOTALL allows the match to span multiple lines
     maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
     if maybe_json_str is not None:
         return maybe_json_str.group(0)
@@ -55,6 +82,18 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
 
 
 def convert_response_to_json(response: str) -> dict:
+    """Convert a response string to a JSON object.
+
+    Args:
+        response: Response string.
+
+    Returns:
+        The parsed JSON dictionary.
+
+    Raises:
+        AssertionError: If no JSON can be located in the response.
+        JSONDecodeError: If JSON parsing fails.
+    """
     json_str = locate_json_string_body_from_string(response)
     assert json_str is not None, f"Unable to parse JSON from response: {response}"
     try:
@@ -66,23 +105,48 @@ def convert_response_to_json(response: str) -> dict:
 
 
 def compute_args_hash(*args):
+    """Compute the MD5 hash of the arguments.
+
+    Args:
+        *args: Any number of arguments.
+
+    Returns:
+        Hex string of the MD5 hash of the arguments.
+    """
     return md5(str(args).encode()).hexdigest()
 
 
 def compute_mdhash_id(content, prefix: str = ""):
+    """Compute an MD5-hash ID for the content.
+
+    Args:
+        content: The content to hash.
+        prefix: Prefix for the hash value (empty by default).
+
+    Returns:
+        The MD5 hash with the prefix prepended.
+    """
     return prefix + md5(content.encode()).hexdigest()
 
 
 def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
-    """Add restriction of maximum async calling times for a async func"""
-
+    """Decorator that limits the maximum number of concurrent calls to an async function.
+
+    Args:
+        max_size: Maximum number of concurrent calls.
+        waitting_time: Polling interval in seconds while waiting.
+
+    Returns:
+        The decorator function.
+    """
     def final_decro(func):
-        """Not using async.Semaphore to aovid use nest-asyncio"""
+        """Inner decorator; avoids asyncio.Semaphore so that nest-asyncio is not required"""
         __current_size = 0
 
         @wraps(func)
         async def wait_func(*args, **kwargs):
             nonlocal __current_size
+            # Wait while the current concurrency has reached the maximum
             while __current_size >= max_size:
                 await asyncio.sleep(waitting_time)
             __current_size += 1
@@ -96,8 +160,14 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
 
 
 def wrap_embedding_func_with_attrs(**kwargs):
-    """Wrap a function with attributes"""
-
+    """Decorator that wraps an embedding function with attributes.
+
+    Args:
+        **kwargs: Keyword arguments passed to EmbeddingFunc.
+
+    Returns:
+        A decorator that returns an EmbeddingFunc instance.
+    """
     def final_decro(func) -> EmbeddingFunc:
         new_func = EmbeddingFunc(**kwargs, func=func)
         return new_func
@@ -106,6 +176,14 @@ def wrap_embedding_func_with_attrs(**kwargs):
 
 
 def load_json(file_name):
+    """Load JSON data from a file.
+
+    Args:
+        file_name: Path to the JSON file.
+
+    Returns:
+        The loaded JSON object, or None if the file does not exist.
+    """
     if not os.path.exists(file_name):
         return None
     with open(file_name, encoding="utf-8") as f:
@@ -113,11 +191,29 @@ def load_json(file_name):
 
 
 def write_json(json_obj, file_name):
+    """Write a JSON object to a file.
+
+    Args:
+        json_obj: The JSON object to write.
+        file_name: Target file path.
+
+    Note:
+        Formats with indent=2; ensure_ascii=False preserves non-ASCII text such as Chinese.
+    """
     with open(file_name, "w", encoding="utf-8") as f:
         json.dump(json_obj, f, indent=2, ensure_ascii=False)
 
 
 def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
+    """Encode a string into tokens using tiktoken.
+
+    Args:
+        content: The string to encode.
+        model_name: Name of the model to use.
+
+    Returns:
+        The list of encoded tokens.
+    """
     global ENCODER
     if ENCODER is None:
         ENCODER = tiktoken.encoding_for_model(model_name)
@@ -126,6 +222,15 @@ def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
 
 
 def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
+    """Decode tokens into a string using tiktoken.
+
+    Args:
+        tokens: List of tokens.
+        model_name: Name of the model to use.
+
+    Returns:
+        The decoded string.
+    """
     global ENCODER
     if ENCODER is None:
         ENCODER = tiktoken.encoding_for_model(model_name)
@@ -134,6 +239,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
 
 
 def pack_user_ass_to_openai_messages(*args: str):
+    """Pack alternating user and assistant messages into the OpenAI message format.
+
+    Args:
+        *args: Alternating user and assistant messages.
+
+    Returns:
+        A list of OpenAI-format messages; even indices (0, 2, ...) are user messages, odd indices are assistant messages.
+    """
     roles = ["user", "assistant"]
     return [
         {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
@@ -141,7 +254,15 @@ def pack_user_ass_to_openai_messages(*args: str):
 
 
 def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
-    """Split a string by multiple markers"""
+    """Split a string by multiple markers.
+
+    Args:
+        content: The string to split.
+        markers: List of split markers.
+
+    Returns:
+        The list of split substrings, with surrounding whitespace stripped.
+    """
     if not markers:
         return [content]
     results = re.split("|".join(re.escape(marker) for marker in markers), content)
@@ -151,22 +272,44 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
 
 
 # Refer the utils functions of the official GraphRAG implementation:
 # https://github.com/microsoft/graphrag
 def clean_str(input: Any) -> str:
-    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
-    # If we get non-string input, just give it back
+    """Clean HTML escapes and control characters from a string.
+
+    Args:
+        input: Input string (or any other type).
+
+    Returns:
+        The cleaned string; non-string input is returned unchanged.
+    """
     if not isinstance(input, str):
         return input
 
     result = html.unescape(input.strip())
-    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
     return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
 
 
 def is_float_regex(value):
+    """Check whether a string is in floating-point format.
+
+    Args:
+        value: The value to check.
+
+    Returns:
+        Boolean indicating whether the value looks like a floating-point number.
+    """
     return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
 
 
 def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
-    """Truncate a list of data by token size"""
+    """Truncate a list of data by token size.
+
+    Args:
+        list_data: The list to truncate.
+        key: Function that extracts the text from a list item.
+        max_token_size: Maximum number of tokens.
+
+    Returns:
+        The truncated list.
+    """
     if max_token_size <= 0:
         return []
     tokens = 0
@@ -178,6 +321,14 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
 
 
 def list_of_list_to_csv(data: List[List[str]]) -> str:
+    """Convert a two-dimensional list into a CSV string.
+
+    Args:
+        data: Two-dimensional list of strings.
+
+    Returns:
+        A CSV-formatted string.
+    """
     output = io.StringIO()
     writer = csv.writer(output)
     writer.writerows(data)
@@ -185,65 +336,123 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
 
 
 def csv_string_to_list(csv_string: str) -> List[List[str]]:
+    """Convert a CSV string into a two-dimensional list.
+
+    Args:
+        csv_string: A CSV-formatted string.
+
+    Returns:
+        Two-dimensional list of strings.
+    """
     output = io.StringIO(csv_string)
     reader = csv.reader(output)
     return [row for row in reader]
 
 
 def save_data_to_file(data, file_name):
+    """Save data to a JSON file.
+
+    Args:
+        data: The data to save.
+        file_name: Target file path.
+    """
     with open(file_name, "w", encoding="utf-8") as f:
         json.dump(data, f, ensure_ascii=False, indent=4)
 
 
 def xml_to_json(xml_file):
+    """Convert a GraphML XML file to JSON.
+
+    Args:
+        xml_file: Path to the GraphML file.
+
+    Returns:
+        A dictionary with node and edge information, or None if parsing fails.
+
+    Note:
+        The converted data structure is:
+        {
+            "nodes": [
+                {
+                    "id": str,
+                    "entity_type": str,
+                    "description": str,
+                    "source_id": str
+                }
+            ],
+            "edges": [
+                {
+                    "source": str,
+                    "target": str,
+                    "weight": float,
+                    "description": str,
+                    "keywords": str,
+                    "source_id": str
+                }
+            ]
+        }
+    """
     try:
+        # Parse the XML file
         tree = ET.parse(xml_file)
         root = tree.getroot()
 
-        # Print the root element's tag and attributes to confirm the file has been correctly loaded
+        # Print root element info to confirm the file loaded correctly
         print(f"Root element: {root.tag}")
         print(f"Root attributes: {root.attrib}")
 
+        # Initialize the data structure
        data = {"nodes": [], "edges": []}
 
-        # Use namespace
+        # Set the GraphML namespace
         namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
 
+        # Process all nodes
         for node in root.findall(".//node", namespace):
             node_data = {
+                # Get the node ID and strip quotes
                 "id": node.get("id").strip('"'),
+                # Get the entity type; empty string if absent
                 "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
                 if node.find("./data[@key='d0']", namespace) is not None
                 else "",
+                # Get the description
                 "description": node.find("./data[@key='d1']", namespace).text
                 if node.find("./data[@key='d1']", namespace) is not None
                 else "",
+                # Get the source ID
                 "source_id": node.find("./data[@key='d2']", namespace).text
                 if node.find("./data[@key='d2']", namespace) is not None
                 else "",
             }
             data["nodes"].append(node_data)
 
+        # Process all edges
         for edge in root.findall(".//edge", namespace):
             edge_data = {
+                # Get the edge's source and target nodes
                 "source": edge.get("source").strip('"'),
                 "target": edge.get("target").strip('"'),
+                # Get the weight; defaults to 0.0
                 "weight": float(edge.find("./data[@key='d3']", namespace).text)
                 if edge.find("./data[@key='d3']", namespace) is not None
                 else 0.0,
+                # Get the description
                 "description": edge.find("./data[@key='d4']", namespace).text
                 if edge.find("./data[@key='d4']", namespace) is not None
                 else "",
+                # Get the keywords
                 "keywords": edge.find("./data[@key='d5']", namespace).text
                 if edge.find("./data[@key='d5']", namespace) is not None
                 else "",
+                # Get the source ID
                 "source_id": edge.find("./data[@key='d6']", namespace).text
                 if edge.find("./data[@key='d6']", namespace) is not None
                 else "",
             }
             data["edges"].append(edge_data)
 
-        # Print the number of nodes and edges found
+        # Print summary statistics
         print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
 
         return data
@@ -256,31 +465,55 @@ def xml_to_json(xml_file):
 
 
 def process_combine_contexts(hl, ll):
+    """Merge high-level and low-level context information.
+
+    Args:
+        hl: CSV string of the high-level context.
+        ll: CSV string of the low-level context.
+
+    Returns:
+        The merged CSV-formatted string.
+
+    Note:
+        Processing steps:
+        1. Parse the input CSV strings.
+        2. Extract and keep the header.
+        3. Merge the data rows and deduplicate.
+        4. Reformat into a CSV string.
+    """
+    # Initialize the header
     header = None
+    # Parse the CSV strings
     list_hl = csv_string_to_list(hl.strip())
     list_ll = csv_string_to_list(ll.strip())
 
+    # Extract the header
     if list_hl:
         header = list_hl[0]
-        list_hl = list_hl[1:]
+        list_hl = list_hl[1:]  # drop the header row
     if list_ll:
         header = list_ll[0]
-        list_ll = list_ll[1:]
+        list_ll = list_ll[1:]  # drop the header row
     if header is None:
         return ""
 
+    # For each data row, keep only the columns after the first
     if list_hl:
         list_hl = [",".join(item[1:]) for item in list_hl if item]
     if list_ll:
         list_ll = [",".join(item[1:]) for item in list_ll if item]
 
+    # Merge the data and deduplicate
     combined_sources_set = set(filter(None, list_hl + list_ll))
 
-    combined_sources = [",\t".join(header)]
+    # Rebuild the CSV string
+    combined_sources = [",\t".join(header)]  # add the header
 
+    # Append the data rows with new sequence numbers
     for i, item in enumerate(combined_sources_set, start=1):
         combined_sources.append(f"{i},\t{item}")
 
+    # Join all rows with newlines
     combined_sources = "\n".join(combined_sources)
 
     return combined_sources
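
Reviewer note (not part of the patch): a minimal usage sketch of the APIs documented above, assuming the lightrag package layout shown in this diff (lightrag/base.py and lightrag/utils.py). The names fake_embed and fetch, the embedding dimension, and the token limit are illustrative placeholders, not identifiers from the codebase.

    import asyncio

    import numpy as np

    from lightrag.base import QueryParam
    from lightrag.utils import limit_async_func_call, wrap_embedding_func_with_attrs


    # Hypothetical embedding backend; wrapping it yields an awaitable EmbeddingFunc
    # instance carrying embedding_dim and max_token_size as attributes.
    @wrap_embedding_func_with_attrs(embedding_dim=8, max_token_size=8192)
    async def fake_embed(texts: list[str]) -> np.ndarray:
        return np.random.rand(len(texts), 8)


    # Allow at most 4 concurrent in-flight calls to this coroutine.
    @limit_async_func_call(max_size=4)
    async def fetch(i: int) -> int:
        await asyncio.sleep(0.01)
        return i


    async def main():
        vectors = await fake_embed(["hello", "world"])  # EmbeddingFunc.__call__ is async
        print(vectors.shape)  # (2, 8)

        results = await asyncio.gather(*(fetch(i) for i in range(10)))
        print(results)

        # Query parameters documented in base.py; defaults can be overridden per call.
        print(QueryParam(mode="local", top_k=20))


    asyncio.run(main())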