使用cursor加注释。

This commit is contained in:
many2many 2024-11-16 11:59:20 +08:00
parent 6969e4afc1
commit 3970233c36
7 changed files with 331 additions and 133 deletions

View File

@ -1,9 +1,23 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar
# 从dataclasses模块导入数据类相关工具
from dataclasses import (
dataclass, # 数据类装饰器,用于简化类的定义
field # 字段函数,用于定义特殊的字段属性
)
# 从typing模块导入类型提示工具
from typing import (
TypedDict, # 类型化字典,用于定义具有特定类型的字典
Union, # 联合类型,表示多个可能的类型之一
Literal, # 字面量类型,用于限定特定的值
Generic, # 泛型基类,用于创建泛型类
TypeVar # 类型变量,用于泛型编程
)
# 导入numpy用于数值计算
import numpy as np
from .utils import EmbeddingFunc
# 从本地utils模块导入嵌入函数类
from .utils import EmbeddingFunc # 用于处理文本嵌入的函数类
# 定义文本块的数据结构包含令牌数、内容、完整文档ID和块序号
TextChunkSchema = TypedDict(

View File

@ -1,103 +1,181 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import inspect
from lightrag.utils import logger
from ..base import BaseGraphStorage
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
# 标准库导入
import asyncio # 异步IO支持
import os # 操作系统接口,用于环境变量访问
import inspect # 用于运行时检查Python对象
# 数据类和类型提示相关导入
from dataclasses import dataclass # 数据类装饰器
from typing import (
Any, # 任意类型
Union, # 联合类型
Tuple, # 元组类型
List, # 列表类型
Dict # 字典类型
)
# 本地模块导入
from lightrag.utils import logger # 日志记录器
from ..base import BaseGraphStorage # 图存储基类
# Neo4j相关导入
from neo4j import (
AsyncGraphDatabase, # Neo4j异步图数据库驱动
exceptions as neo4jExceptions, # Neo4j异常类
AsyncDriver, # Neo4j异步驱动接口
AsyncManagedTransaction, # Neo4j异步事务管理
)
# 重试机制相关导入
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
retry, # 重试装饰器
stop_after_attempt, # 最大重试次数限制
wait_exponential, # 指数退避等待策略
retry_if_exception_type, # 基于异常类型的重试条件
)
@dataclass
class Neo4JStorage(BaseGraphStorage):
"""Neo4j图数据库存储实现类"""
@staticmethod
def load_nx_graph(file_name):
"""加载NetworkX图的静态方法生产环境中未使用
Args:
file_name: 图文件名
"""
print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config):
"""初始化Neo4j存储实例
Args:
namespace: 命名空间
global_config: 全局配置
Note:
从环境变量中读取Neo4j连接信息并初始化驱动
"""
# 调用父类初始化
super().__init__(namespace=namespace, global_config=global_config)
# 初始化驱动相关属性
self._driver = None
self._driver_lock = asyncio.Lock()
self._driver_lock = asyncio.Lock() # 异步锁,用于并发控制
# 从环境变量获取Neo4j连接信息
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
# 初始化Neo4j异步驱动
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None
def __post_init__(self):
"""数据类后初始化方法,设置节点嵌入算法"""
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self):
"""关闭数据库连接"""
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
"""异步上下文管理器的退出方法"""
if self._driver:
await self._driver.close()
async def index_done_callback(self):
"""索引完成回调方法"""
print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool:
"""检查节点是否存在
Args:
node_id: 节点ID
Returns:
bool: 节点是否存在
"""
# 清理节点ID中的引号
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
# 构建Cypher查询
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
# 执行查询
result = await session.run(query)
single_result = await result.single()
# 记录调试日志
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""检查边是否存在
Args:
source_node_id: 源节点ID
target_node_id: 目标节点ID
Returns:
bool: 边是否存在
"""
# 清理节点ID中的引号
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
# 构建Cypher查询
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
# 执行查询
result = await session.run(query)
single_result = await result.single()
# 记录调试日志
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
return single_result["edgeExists"]
def close(self):
"""同步关闭方法(注意:这是一个缩进错误,应该与其他方法对齐)"""
self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]:
"""获取节点信息
Args:
node_id: 节点ID
Returns:
dict: 节点属性字典如果节点不存在则返回None
"""
async with self._driver.session() as session:
# 清理节点ID中的引号
entity_name_label = node_id.strip('"')
# 构建Cypher查询
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
# 执行查询
result = await session.run(query)
record = await result.single()
if record:
# 提取节点数据并转换为字典
node = record["n"]
node_dict = dict(node)
# 记录调试日志
logger.debug(
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
)
@ -105,9 +183,19 @@ class Neo4JStorage(BaseGraphStorage):
return None
async def node_degree(self, node_id: str) -> int:
"""获取节点的度(与节点相连的边的数量)
Args:
node_id: 节点ID
Returns:
int: 节点的度如果节点不存在则返回None
"""
# 清理节点ID中的引号
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
# 构建Cypher查询计算节点的总边数
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@ -116,6 +204,7 @@ class Neo4JStorage(BaseGraphStorage):
record = await result.single()
if record:
edge_count = record["totalEdgeCount"]
# 记录调试日志
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
)
@ -124,15 +213,28 @@ class Neo4JStorage(BaseGraphStorage):
return None
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""计算边的度(源节点和目标节点的度之和)
Args:
src_id: 源节点ID
tgt_id: 目标节点ID
Returns:
int: 边的度两个节点的度之和
"""
# 清理节点ID中的引号
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
# 获取源节点和目标节点的度
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition
# 将None转换为0以进行加法运算
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
# 计算总度数并记录日志
degrees = int(src_degree) + int(trg_degree)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
@ -142,19 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
"""获取两个节点之间的边的属性
Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
source_node_id: 源节点ID
target_node_id: 目标节点ID
Returns:
list: List of all relationships/edges found
dict: 边的属性字典如果边不存在则返回None
"""
# 清理节点ID中的引号
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
# 构建Cypher查询获取边的属性
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
@ -164,9 +268,11 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_target=entity_name_label_target,
)
# 执行查询
result = await session.run(query)
record = await result.single()
if record:
# 转换结果为字典并记录日志
result = dict(record["edge_properties"])
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
@ -176,40 +282,49 @@ class Neo4JStorage(BaseGraphStorage):
return None
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
"""获取指定节点的所有边
Args:
source_node_id: 源节点ID
Returns:
List[Tuple[str, str]]: 边列表每个元素为(源节点标签, 目标节点标签)的元组
"""
node_label = source_node_id.strip('"')
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
"""
# 构建Cypher查询获取节点及其所有关系
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session() as session:
results = await session.run(query)
edges = []
# 异步迭代处理查询结果
async for record in results:
source_node = record["n"]
connected_node = record["connected"]
# 获取源节点标签(取第一个标签)
source_label = (
list(source_node.labels)[0] if source_node.labels else None
)
# 获取目标节点标签(取第一个标签)
target_label = (
list(connected_node.labels)[0]
if connected_node and connected_node.labels
else None
)
# 如果源节点和目标节点都有标签,则添加到边列表
if source_label and target_label:
edges.append((source_label, target_label))
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
stop=stop_after_attempt(3), # 最多重试3次
wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待
retry=retry_if_exception_type( # 指定需要重试的异常类型
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
@ -218,17 +333,17 @@ class Neo4JStorage(BaseGraphStorage):
),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the Neo4j database.
"""更新或插入节点
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
node_id: 节点的唯一标识符用作标签
node_data: 节点属性字典
"""
label = node_id.strip('"')
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
"""执行节点更新/插入的内部函数"""
query = f"""
MERGE (n:`{label}`)
SET n += $properties
@ -246,9 +361,9 @@ class Neo4JStorage(BaseGraphStorage):
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
stop=stop_after_attempt(3), # 最多重试3次
wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待
retry=retry_if_exception_type( # 指定需要重试的异常类型
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
@ -259,19 +374,19 @@ class Neo4JStorage(BaseGraphStorage):
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
"""
Upsert an edge and its properties between two nodes identified by their labels.
"""更新或插入边及其属性
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
source_node_id: 源节点标签用作标识符
target_node_id: 目标节点标签用作标识符
edge_data: 边属性字典
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction):
"""执行边更新/插入的内部函数"""
query = f"""
MATCH (source:`{source_node_label}`)
WITH source
@ -293,4 +408,5 @@ class Neo4JStorage(BaseGraphStorage):
raise
async def _node2vec_embed(self):
"""节点嵌入方法(未实际使用)"""
print("Implemented but never called.")

View File

@ -3,12 +3,30 @@ LightRAG - 轻量级检索增强生成系统
该模块实现了一个基于图的文档检索和问答系统支持文档的存储检索和知识图谱构建
"""
# 导入异步IO模块用于处理异步编程
import asyncio
# 导入操作系统接口模块,用于处理文件路径和环境变量
import os
from dataclasses import asdict, dataclass, field
# 从dataclasses模块导入数据类相关工具
from dataclasses import (
asdict, # 将数据类实例转换为字典的函数
dataclass, # 数据类装饰器
field, # 用于定义数据类字段的函数
)
# 导入日期时间处理模块
from datetime import datetime
# 从functools导入partial函数用于创建偏函数
from functools import partial
from typing import Type, cast
# 从typing模块导入类型提示工具
from typing import (
Type, # 用于类型注解中表示类型的类型
cast, # 用于类型转换的函数
)
# 导入LLM相关功能
from .llm import (

View File

@ -1,35 +1,51 @@
import os
import copy
from functools import lru_cache
import json
import aioboto3
import aiohttp
import numpy as np
import ollama
# 标准库导入
import os # 操作系统接口
import copy # 深浅拷贝功能
from functools import lru_cache # 最近最少使用缓存装饰器
import json # JSON数据处理
import base64 # Base64编解码
import struct # 处理二进制数据结构
# 第三方异步库
import aioboto3 # AWS SDK的异步版本
import aiohttp # 异步HTTP客户端/服务器
import ollama # Ollama API客户端
# 数值计算和机器学习库
import numpy as np # 数值计算库
import torch # PyTorch深度学习框架
from transformers import ( # Hugging Face转换器库
AutoTokenizer, # 自动分词器
AutoModelForCausalLM, # 自动因果语言模型
)
# OpenAI相关导入
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
AsyncOpenAI, # OpenAI异步客户端
APIConnectionError, # API连接错误
RateLimitError, # 速率限制错误
Timeout, # 超时错误
AsyncAzureOpenAI, # Azure OpenAI异步客户端
)
import base64
import struct
# 重试机制相关导入
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
retry, # 重试装饰器
stop_after_attempt, # 最大重试次数
wait_exponential, # 指数退避等待
retry_if_exception_type, # 基于异常类型的重试条件
)
# 数据验证和类型提示
from pydantic import BaseModel, Field # 数据验证模型
from typing import List, Dict, Callable, Any # 类型提示
# 本地模块导入
from .base import BaseKVStorage # 键值存储基类
from .utils import (
compute_args_hash, # 计算参数哈希值
wrap_embedding_func_with_attrs, # 包装嵌入函数的装饰器
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
# 禁用并行化以避免tokenizers的并行化导致的问题
os.environ["TOKENIZERS_PARALLELISM"] = "false"

View File

@ -1,30 +1,43 @@
import asyncio
import json
import re
from typing import Union
from collections import Counter, defaultdict
import warnings
# 标准库导入
import asyncio # 异步IO支持
import json # JSON数据处理
import re # 正则表达式支持
from typing import Union # 类型提示:联合类型
from collections import (
Counter, # 计数器集合类
defaultdict # 带默认值的字典
)
import warnings # 警告控制
# 从本地utils模块导入工具函数
from .utils import (
logger,
clean_str,
compute_mdhash_id,
decode_tokens_by_tiktoken,
encode_string_by_tiktoken,
is_float_regex,
list_of_list_to_csv,
pack_user_ass_to_openai_messages,
split_string_by_multi_markers,
truncate_list_by_token_size,
process_combine_contexts,
logger, # 日志记录器
clean_str, # 字符串清理函数
compute_mdhash_id, # 计算MD5哈希ID
decode_tokens_by_tiktoken, # tiktoken解码函数
encode_string_by_tiktoken, # tiktoken编码函数
is_float_regex, # 浮点数检查函数
list_of_list_to_csv, # 列表转CSV函数
pack_user_ass_to_openai_messages, # OpenAI消息打包函数
split_string_by_multi_markers, # 多标记字符串分割函数
truncate_list_by_token_size, # 基于token大小截断列表
process_combine_contexts, # 上下文合并处理函数
)
# 从本地base模块导入基础类
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
TextChunkSchema,
QueryParam,
BaseGraphStorage, # 图存储基类
BaseKVStorage, # 键值存储基类
BaseVectorStorage, # 向量存储基类
TextChunkSchema, # 文本块模式定义
QueryParam, # 查询参数类
)
# 从本地prompt模块导入提示相关常量
from .prompt import (
GRAPH_FIELD_SEP, # 图字段分隔符
PROMPTS # 提示模板集合
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
def chunking_by_token_size(

View File

@ -1,23 +1,34 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
# 标准库导入
import asyncio # 异步IO支持
import html # HTML实体编解码
import os # 操作系统接口,用于文件和路径操作
from .utils import (
logger,
load_json,
write_json,
compute_mdhash_id,
# 数据类和类型提示相关导入
from dataclasses import dataclass # 数据类装饰器
from typing import (
Any, # 任意类型
Union, # 联合类型
cast # 类型转换函数
)
# 第三方库导入
import networkx as nx # 图数据处理库
import numpy as np # 数值计算库
from nano_vectordb import NanoVectorDB # 向量数据库
# 从本地utils模块导入工具函数
from .utils import (
logger, # 日志记录器
load_json, # JSON文件加载函数
write_json, # JSON文件写入函数
compute_mdhash_id, # 计算MD5哈希ID的函数
)
# 从本地base模块导入基础存储类
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
BaseGraphStorage, # 图存储基类
BaseKVStorage, # 键值存储基类
BaseVectorStorage, # 向量存储基类
)
@dataclass

View File

@ -1,19 +1,29 @@
import asyncio
import html
import io
import csv
import json
import logging
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List
import xml.etree.ElementTree as ET
# 标准库导入 - 异步和IO操作
import asyncio # 异步IO支持
import html # HTML实体编解码
import io # 内存IO操作
import csv # CSV文件处理
import json # JSON数据处理
import logging # 日志记录
import os # 操作系统接口
import re # 正则表达式
import numpy as np
import tiktoken
# 标准库导入 - 数据结构和工具
from dataclasses import dataclass # 数据类装饰器
from functools import wraps # 装饰器工具
from hashlib import md5 # MD5哈希算法
from typing import ( # 类型提示
Any, # 任意类型
Union, # 联合类型
List # 列表类型
)
# XML处理
import xml.etree.ElementTree as ET # XML解析和处理
# 第三方库导入
import numpy as np # 数值计算库
import tiktoken # OpenAI的分词器
# 全局编码器变量
ENCODER = None