From 3970233c36b9c3eb14c3cc278c556dd923df2252 Mon Sep 17 00:00:00 2001 From: many2many <6168830@qq.com> Date: Sat, 16 Nov 2024 11:59:20 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8cursor=E5=8A=A0=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/base.py | 20 +++- lightrag/kg/neo4j_impl.py | 212 +++++++++++++++++++++++++++++--------- lightrag/lightrag.py | 22 +++- lightrag/llm.py | 68 +++++++----- lightrag/operate.py | 59 ++++++----- lightrag/storage.py | 43 +++++--- lightrag/utils.py | 40 ++++--- 7 files changed, 331 insertions(+), 133 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index cac1e79..830c4b3 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,9 +1,23 @@ -from dataclasses import dataclass, field -from typing import TypedDict, Union, Literal, Generic, TypeVar +# 从dataclasses模块导入数据类相关工具 +from dataclasses import ( + dataclass, # 数据类装饰器,用于简化类的定义 + field # 字段函数,用于定义特殊的字段属性 +) +# 从typing模块导入类型提示工具 +from typing import ( + TypedDict, # 类型化字典,用于定义具有特定类型的字典 + Union, # 联合类型,表示多个可能的类型之一 + Literal, # 字面量类型,用于限定特定的值 + Generic, # 泛型基类,用于创建泛型类 + TypeVar # 类型变量,用于泛型编程 +) + +# 导入numpy用于数值计算 import numpy as np -from .utils import EmbeddingFunc +# 从本地utils模块导入嵌入函数类 +from .utils import EmbeddingFunc # 用于处理文本嵌入的函数类 # 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号 TextChunkSchema = TypedDict( diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index e6b33a9..dbda2a8 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,103 +1,181 @@ -import asyncio -import os -from dataclasses import dataclass -from typing import Any, Union, Tuple, List, Dict -import inspect -from lightrag.utils import logger -from ..base import BaseGraphStorage -from neo4j import ( - AsyncGraphDatabase, - exceptions as neo4jExceptions, - AsyncDriver, - AsyncManagedTransaction, +# 标准库导入 +import asyncio # 异步IO支持 +import os # 操作系统接口,用于环境变量访问 +import inspect # 用于运行时检查Python对象 + +# 数据类和类型提示相关导入 +from dataclasses import dataclass # 数据类装饰器 +from typing import ( + Any, # 任意类型 + Union, # 联合类型 + Tuple, # 元组类型 + List, # 列表类型 + Dict # 字典类型 ) +# 本地模块导入 +from lightrag.utils import logger # 日志记录器 +from ..base import BaseGraphStorage # 图存储基类 +# Neo4j相关导入 +from neo4j import ( + AsyncGraphDatabase, # Neo4j异步图数据库驱动 + exceptions as neo4jExceptions, # Neo4j异常类 + AsyncDriver, # Neo4j异步驱动接口 + AsyncManagedTransaction, # Neo4j异步事务管理 +) + +# 重试机制相关导入 from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, + retry, # 重试装饰器 + stop_after_attempt, # 最大重试次数限制 + wait_exponential, # 指数退避等待策略 + retry_if_exception_type, # 基于异常类型的重试条件 ) @dataclass class Neo4JStorage(BaseGraphStorage): + """Neo4j图数据库存储实现类""" + @staticmethod def load_nx_graph(file_name): + """加载NetworkX图的静态方法(生产环境中未使用) + + Args: + file_name: 图文件名 + """ print("no preloading of graph with neo4j in production") def __init__(self, namespace, global_config): + """初始化Neo4j存储实例 + + Args: + namespace: 命名空间 + global_config: 全局配置 + + Note: + 从环境变量中读取Neo4j连接信息并初始化驱动 + """ + # 调用父类初始化 super().__init__(namespace=namespace, global_config=global_config) + # 初始化驱动相关属性 self._driver = None - self._driver_lock = asyncio.Lock() + self._driver_lock = asyncio.Lock() # 异步锁,用于并发控制 + + # 从环境变量获取Neo4j连接信息 URI = os.environ["NEO4J_URI"] USERNAME = os.environ["NEO4J_USERNAME"] PASSWORD = os.environ["NEO4J_PASSWORD"] + + # 初始化Neo4j异步驱动 self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) ) return None def __post_init__(self): + """数据类后初始化方法,设置节点嵌入算法""" self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } async def close(self): + """关闭数据库连接""" if self._driver: await self._driver.close() self._driver = None async def __aexit__(self, exc_type, exc, tb): + """异步上下文管理器的退出方法""" if self._driver: await self._driver.close() async def index_done_callback(self): + """索引完成回调方法""" print("KG successfully indexed.") async def has_node(self, node_id: str) -> bool: + """检查节点是否存在 + + Args: + node_id: 节点ID + + Returns: + bool: 节点是否存在 + """ + # 清理节点ID中的引号 entity_name_label = node_id.strip('"') async with self._driver.session() as session: + # 构建Cypher查询 query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" ) + # 执行查询 result = await session.run(query) single_result = await result.single() + # 记录调试日志 logger.debug( f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' ) return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """检查边是否存在 + + Args: + source_node_id: 源节点ID + target_node_id: 目标节点ID + + Returns: + bool: 边是否存在 + """ + # 清理节点ID中的引号 entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') async with self._driver.session() as session: + # 构建Cypher查询 query = ( f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" ) + # 执行查询 result = await session.run(query) single_result = await result.single() + # 记录调试日志 logger.debug( f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' ) return single_result["edgeExists"] def close(self): + """同步关闭方法(注意:这是一个缩进错误,应该与其他方法对齐)""" self._driver.close() async def get_node(self, node_id: str) -> Union[dict, None]: + """获取节点信息 + + Args: + node_id: 节点ID + + Returns: + dict: 节点属性字典,如果节点不存在则返回None + """ async with self._driver.session() as session: + # 清理节点ID中的引号 entity_name_label = node_id.strip('"') + # 构建Cypher查询 query = f"MATCH (n:`{entity_name_label}`) RETURN n" + # 执行查询 result = await session.run(query) record = await result.single() if record: + # 提取节点数据并转换为字典 node = record["n"] node_dict = dict(node) + # 记录调试日志 logger.debug( f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" ) @@ -105,9 +183,19 @@ class Neo4JStorage(BaseGraphStorage): return None async def node_degree(self, node_id: str) -> int: + """获取节点的度(与节点相连的边的数量) + + Args: + node_id: 节点ID + + Returns: + int: 节点的度,如果节点不存在则返回None + """ + # 清理节点ID中的引号 entity_name_label = node_id.strip('"') async with self._driver.session() as session: + # 构建Cypher查询,计算节点的总边数 query = f""" MATCH (n:`{entity_name_label}`) RETURN COUNT{{ (n)--() }} AS totalEdgeCount @@ -116,6 +204,7 @@ class Neo4JStorage(BaseGraphStorage): record = await result.single() if record: edge_count = record["totalEdgeCount"] + # 记录调试日志 logger.debug( f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}" ) @@ -124,15 +213,28 @@ class Neo4JStorage(BaseGraphStorage): return None async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """计算边的度(源节点和目标节点的度之和) + + Args: + src_id: 源节点ID + tgt_id: 目标节点ID + + Returns: + int: 边的度(两个节点的度之和) + """ + # 清理节点ID中的引号 entity_name_label_source = src_id.strip('"') entity_name_label_target = tgt_id.strip('"') + + # 获取源节点和目标节点的度 src_degree = await self.node_degree(entity_name_label_source) trg_degree = await self.node_degree(entity_name_label_target) - # Convert None to 0 for addition + # 将None转换为0以进行加法运算 src_degree = 0 if src_degree is None else src_degree trg_degree = 0 if trg_degree is None else trg_degree + # 计算总度数并记录日志 degrees = int(src_degree) + int(trg_degree) logger.debug( f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}" @@ -142,19 +244,21 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: + """获取两个节点之间的边的属性 + + Args: + source_node_id: 源节点ID + target_node_id: 目标节点ID + + Returns: + dict: 边的属性字典,如果边不存在则返回None + """ + # 清理节点ID中的引号 entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - """ - Find all edges between nodes of two given labels - Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ async with self._driver.session() as session: + # 构建Cypher查询,获取边的属性 query = f""" MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties @@ -164,9 +268,11 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_target=entity_name_label_target, ) + # 执行查询 result = await session.run(query) record = await result.single() if record: + # 转换结果为字典并记录日志 result = dict(record["edge_properties"]) logger.debug( f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" @@ -176,40 +282,49 @@ class Neo4JStorage(BaseGraphStorage): return None async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + """获取指定节点的所有边 + + Args: + source_node_id: 源节点ID + + Returns: + List[Tuple[str, str]]: 边列表,每个元素为(源节点标签, 目标节点标签)的元组 + """ node_label = source_node_id.strip('"') - """ - Retrieves all edges (relationships) for a particular node identified by its label. - :return: List of dictionaries containing edge information - """ + # 构建Cypher查询,获取节点及其所有关系 query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" async with self._driver.session() as session: results = await session.run(query) edges = [] + # 异步迭代处理查询结果 async for record in results: source_node = record["n"] connected_node = record["connected"] + # 获取源节点标签(取第一个标签) source_label = ( list(source_node.labels)[0] if source_node.labels else None ) + # 获取目标节点标签(取第一个标签) target_label = ( list(connected_node.labels)[0] if connected_node and connected_node.labels else None ) + # 如果源节点和目标节点都有标签,则添加到边列表 if source_label and target_label: edges.append((source_label, target_label)) return edges @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( + stop=stop_after_attempt(3), # 最多重试3次 + wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待 + retry=retry_if_exception_type( # 指定需要重试的异常类型 ( neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, @@ -218,17 +333,17 @@ class Neo4JStorage(BaseGraphStorage): ), ) async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): - """ - Upsert a node in the Neo4j database. - + """更新或插入节点 + Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties + node_id: 节点的唯一标识符(用作标签) + node_data: 节点属性字典 """ label = node_id.strip('"') properties = node_data async def _do_upsert(tx: AsyncManagedTransaction): + """执行节点更新/插入的内部函数""" query = f""" MERGE (n:`{label}`) SET n += $properties @@ -246,9 +361,9 @@ class Neo4JStorage(BaseGraphStorage): raise @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( + stop=stop_after_attempt(3), # 最多重试3次 + wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待 + retry=retry_if_exception_type( # 指定需要重试的异常类型 ( neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, @@ -259,19 +374,19 @@ class Neo4JStorage(BaseGraphStorage): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] ): - """ - Upsert an edge and its properties between two nodes identified by their labels. - + """更新或插入边及其属性 + Args: - source_node_id (str): Label of the source node (used as identifier) - target_node_id (str): Label of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge + source_node_id: 源节点标签(用作标识符) + target_node_id: 目标节点标签(用作标识符) + edge_data: 边属性字典 """ source_node_label = source_node_id.strip('"') target_node_label = target_node_id.strip('"') edge_properties = edge_data async def _do_upsert_edge(tx: AsyncManagedTransaction): + """执行边更新/插入的内部函数""" query = f""" MATCH (source:`{source_node_label}`) WITH source @@ -293,4 +408,5 @@ class Neo4JStorage(BaseGraphStorage): raise async def _node2vec_embed(self): + """节点嵌入方法(未实际使用)""" print("Implemented but never called.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index e7ccec2..e2c4631 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -3,12 +3,30 @@ LightRAG - 轻量级检索增强生成系统 该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建 """ +# 导入异步IO模块,用于处理异步编程 import asyncio + +# 导入操作系统接口模块,用于处理文件路径和环境变量 import os -from dataclasses import asdict, dataclass, field + +# 从dataclasses模块导入数据类相关工具 +from dataclasses import ( + asdict, # 将数据类实例转换为字典的函数 + dataclass, # 数据类装饰器 + field, # 用于定义数据类字段的函数 +) + +# 导入日期时间处理模块 from datetime import datetime + +# 从functools导入partial函数,用于创建偏函数 from functools import partial -from typing import Type, cast + +# 从typing模块导入类型提示工具 +from typing import ( + Type, # 用于类型注解中表示类型的类型 + cast, # 用于类型转换的函数 +) # 导入LLM相关功能 from .llm import ( diff --git a/lightrag/llm.py b/lightrag/llm.py index eb0a067..97b4476 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,35 +1,51 @@ -import os -import copy -from functools import lru_cache -import json -import aioboto3 -import aiohttp -import numpy as np -import ollama +# 标准库导入 +import os # 操作系统接口 +import copy # 深浅拷贝功能 +from functools import lru_cache # 最近最少使用缓存装饰器 +import json # JSON数据处理 +import base64 # Base64编解码 +import struct # 处理二进制数据结构 +# 第三方异步库 +import aioboto3 # AWS SDK的异步版本 +import aiohttp # 异步HTTP客户端/服务器 +import ollama # Ollama API客户端 + +# 数值计算和机器学习库 +import numpy as np # 数值计算库 +import torch # PyTorch深度学习框架 +from transformers import ( # Hugging Face转换器库 + AutoTokenizer, # 自动分词器 + AutoModelForCausalLM, # 自动因果语言模型 +) + +# OpenAI相关导入 from openai import ( - AsyncOpenAI, - APIConnectionError, - RateLimitError, - Timeout, - AsyncAzureOpenAI, + AsyncOpenAI, # OpenAI异步客户端 + APIConnectionError, # API连接错误 + RateLimitError, # 速率限制错误 + Timeout, # 超时错误 + AsyncAzureOpenAI, # Azure OpenAI异步客户端 ) -import base64 -import struct - +# 重试机制相关导入 from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, + retry, # 重试装饰器 + stop_after_attempt, # 最大重试次数 + wait_exponential, # 指数退避等待 + retry_if_exception_type, # 基于异常类型的重试条件 +) + +# 数据验证和类型提示 +from pydantic import BaseModel, Field # 数据验证模型 +from typing import List, Dict, Callable, Any # 类型提示 + +# 本地模块导入 +from .base import BaseKVStorage # 键值存储基类 +from .utils import ( + compute_args_hash, # 计算参数哈希值 + wrap_embedding_func_with_attrs, # 包装嵌入函数的装饰器 ) -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -from pydantic import BaseModel, Field -from typing import List, Dict, Callable, Any -from .base import BaseKVStorage -from .utils import compute_args_hash, wrap_embedding_func_with_attrs # 禁用并行化以避免tokenizers的并行化导致的问题 os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/lightrag/operate.py b/lightrag/operate.py index bc6e212..3e68bed 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,30 +1,43 @@ -import asyncio -import json -import re -from typing import Union -from collections import Counter, defaultdict -import warnings +# 标准库导入 +import asyncio # 异步IO支持 +import json # JSON数据处理 +import re # 正则表达式支持 +from typing import Union # 类型提示:联合类型 +from collections import ( + Counter, # 计数器集合类 + defaultdict # 带默认值的字典 +) +import warnings # 警告控制 + +# 从本地utils模块导入工具函数 from .utils import ( - logger, - clean_str, - compute_mdhash_id, - decode_tokens_by_tiktoken, - encode_string_by_tiktoken, - is_float_regex, - list_of_list_to_csv, - pack_user_ass_to_openai_messages, - split_string_by_multi_markers, - truncate_list_by_token_size, - process_combine_contexts, + logger, # 日志记录器 + clean_str, # 字符串清理函数 + compute_mdhash_id, # 计算MD5哈希ID + decode_tokens_by_tiktoken, # tiktoken解码函数 + encode_string_by_tiktoken, # tiktoken编码函数 + is_float_regex, # 浮点数检查函数 + list_of_list_to_csv, # 列表转CSV函数 + pack_user_ass_to_openai_messages, # OpenAI消息打包函数 + split_string_by_multi_markers, # 多标记字符串分割函数 + truncate_list_by_token_size, # 基于token大小截断列表 + process_combine_contexts, # 上下文合并处理函数 ) + +# 从本地base模块导入基础类 from .base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, - TextChunkSchema, - QueryParam, + BaseGraphStorage, # 图存储基类 + BaseKVStorage, # 键值存储基类 + BaseVectorStorage, # 向量存储基类 + TextChunkSchema, # 文本块模式定义 + QueryParam, # 查询参数类 +) + +# 从本地prompt模块导入提示相关常量 +from .prompt import ( + GRAPH_FIELD_SEP, # 图字段分隔符 + PROMPTS # 提示模板集合 ) -from .prompt import GRAPH_FIELD_SEP, PROMPTS def chunking_by_token_size( diff --git a/lightrag/storage.py b/lightrag/storage.py index ff828df..7ea353a 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -1,23 +1,34 @@ -import asyncio -import html -import os -from dataclasses import dataclass -from typing import Any, Union, cast -import networkx as nx -import numpy as np -from nano_vectordb import NanoVectorDB +# 标准库导入 +import asyncio # 异步IO支持 +import html # HTML实体编解码 +import os # 操作系统接口,用于文件和路径操作 -from .utils import ( - logger, - load_json, - write_json, - compute_mdhash_id, +# 数据类和类型提示相关导入 +from dataclasses import dataclass # 数据类装饰器 +from typing import ( + Any, # 任意类型 + Union, # 联合类型 + cast # 类型转换函数 ) +# 第三方库导入 +import networkx as nx # 图数据处理库 +import numpy as np # 数值计算库 +from nano_vectordb import NanoVectorDB # 向量数据库 + +# 从本地utils模块导入工具函数 +from .utils import ( + logger, # 日志记录器 + load_json, # JSON文件加载函数 + write_json, # JSON文件写入函数 + compute_mdhash_id, # 计算MD5哈希ID的函数 +) + +# 从本地base模块导入基础存储类 from .base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, + BaseGraphStorage, # 图存储基类 + BaseKVStorage, # 键值存储基类 + BaseVectorStorage, # 向量存储基类 ) @dataclass diff --git a/lightrag/utils.py b/lightrag/utils.py index 57013db..75b9043 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1,19 +1,29 @@ -import asyncio -import html -import io -import csv -import json -import logging -import os -import re -from dataclasses import dataclass -from functools import wraps -from hashlib import md5 -from typing import Any, Union, List -import xml.etree.ElementTree as ET +# 标准库导入 - 异步和IO操作 +import asyncio # 异步IO支持 +import html # HTML实体编解码 +import io # 内存IO操作 +import csv # CSV文件处理 +import json # JSON数据处理 +import logging # 日志记录 +import os # 操作系统接口 +import re # 正则表达式 -import numpy as np -import tiktoken +# 标准库导入 - 数据结构和工具 +from dataclasses import dataclass # 数据类装饰器 +from functools import wraps # 装饰器工具 +from hashlib import md5 # MD5哈希算法 +from typing import ( # 类型提示 + Any, # 任意类型 + Union, # 联合类型 + List # 列表类型 +) + +# XML处理 +import xml.etree.ElementTree as ET # XML解析和处理 + +# 第三方库导入 +import numpy as np # 数值计算库 +import tiktoken # OpenAI的分词器 # 全局编码器变量 ENCODER = None