使用cursor加注释。
This commit is contained in:
parent
6969e4afc1
commit
3970233c36
@ -1,9 +1,23 @@
|
|||||||
from dataclasses import dataclass, field
|
# 从dataclasses模块导入数据类相关工具
|
||||||
from typing import TypedDict, Union, Literal, Generic, TypeVar
|
from dataclasses import (
|
||||||
|
dataclass, # 数据类装饰器,用于简化类的定义
|
||||||
|
field # 字段函数,用于定义特殊的字段属性
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从typing模块导入类型提示工具
|
||||||
|
from typing import (
|
||||||
|
TypedDict, # 类型化字典,用于定义具有特定类型的字典
|
||||||
|
Union, # 联合类型,表示多个可能的类型之一
|
||||||
|
Literal, # 字面量类型,用于限定特定的值
|
||||||
|
Generic, # 泛型基类,用于创建泛型类
|
||||||
|
TypeVar # 类型变量,用于泛型编程
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入numpy用于数值计算
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .utils import EmbeddingFunc
|
# 从本地utils模块导入嵌入函数类
|
||||||
|
from .utils import EmbeddingFunc # 用于处理文本嵌入的函数类
|
||||||
|
|
||||||
# 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号
|
# 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号
|
||||||
TextChunkSchema = TypedDict(
|
TextChunkSchema = TypedDict(
|
||||||
|
@ -1,103 +1,181 @@
|
|||||||
import asyncio
|
# 标准库导入
|
||||||
import os
|
import asyncio # 异步IO支持
|
||||||
from dataclasses import dataclass
|
import os # 操作系统接口,用于环境变量访问
|
||||||
from typing import Any, Union, Tuple, List, Dict
|
import inspect # 用于运行时检查Python对象
|
||||||
import inspect
|
|
||||||
from lightrag.utils import logger
|
# 数据类和类型提示相关导入
|
||||||
from ..base import BaseGraphStorage
|
from dataclasses import dataclass # 数据类装饰器
|
||||||
from neo4j import (
|
from typing import (
|
||||||
AsyncGraphDatabase,
|
Any, # 任意类型
|
||||||
exceptions as neo4jExceptions,
|
Union, # 联合类型
|
||||||
AsyncDriver,
|
Tuple, # 元组类型
|
||||||
AsyncManagedTransaction,
|
List, # 列表类型
|
||||||
|
Dict # 字典类型
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 本地模块导入
|
||||||
|
from lightrag.utils import logger # 日志记录器
|
||||||
|
from ..base import BaseGraphStorage # 图存储基类
|
||||||
|
|
||||||
|
# Neo4j相关导入
|
||||||
|
from neo4j import (
|
||||||
|
AsyncGraphDatabase, # Neo4j异步图数据库驱动
|
||||||
|
exceptions as neo4jExceptions, # Neo4j异常类
|
||||||
|
AsyncDriver, # Neo4j异步驱动接口
|
||||||
|
AsyncManagedTransaction, # Neo4j异步事务管理
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重试机制相关导入
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry, # 重试装饰器
|
||||||
stop_after_attempt,
|
stop_after_attempt, # 最大重试次数限制
|
||||||
wait_exponential,
|
wait_exponential, # 指数退避等待策略
|
||||||
retry_if_exception_type,
|
retry_if_exception_type, # 基于异常类型的重试条件
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Neo4JStorage(BaseGraphStorage):
|
class Neo4JStorage(BaseGraphStorage):
|
||||||
|
"""Neo4j图数据库存储实现类"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_nx_graph(file_name):
|
def load_nx_graph(file_name):
|
||||||
|
"""加载NetworkX图的静态方法(生产环境中未使用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: 图文件名
|
||||||
|
"""
|
||||||
print("no preloading of graph with neo4j in production")
|
print("no preloading of graph with neo4j in production")
|
||||||
|
|
||||||
def __init__(self, namespace, global_config):
|
def __init__(self, namespace, global_config):
|
||||||
|
"""初始化Neo4j存储实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
namespace: 命名空间
|
||||||
|
global_config: 全局配置
|
||||||
|
|
||||||
|
Note:
|
||||||
|
从环境变量中读取Neo4j连接信息并初始化驱动
|
||||||
|
"""
|
||||||
|
# 调用父类初始化
|
||||||
super().__init__(namespace=namespace, global_config=global_config)
|
super().__init__(namespace=namespace, global_config=global_config)
|
||||||
|
# 初始化驱动相关属性
|
||||||
self._driver = None
|
self._driver = None
|
||||||
self._driver_lock = asyncio.Lock()
|
self._driver_lock = asyncio.Lock() # 异步锁,用于并发控制
|
||||||
|
|
||||||
|
# 从环境变量获取Neo4j连接信息
|
||||||
URI = os.environ["NEO4J_URI"]
|
URI = os.environ["NEO4J_URI"]
|
||||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||||
|
|
||||||
|
# 初始化Neo4j异步驱动
|
||||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||||
URI, auth=(USERNAME, PASSWORD)
|
URI, auth=(USERNAME, PASSWORD)
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
"""数据类后初始化方法,设置节点嵌入算法"""
|
||||||
self._node_embed_algorithms = {
|
self._node_embed_algorithms = {
|
||||||
"node2vec": self._node2vec_embed,
|
"node2vec": self._node2vec_embed,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
"""关闭数据库连接"""
|
||||||
if self._driver:
|
if self._driver:
|
||||||
await self._driver.close()
|
await self._driver.close()
|
||||||
self._driver = None
|
self._driver = None
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
"""异步上下文管理器的退出方法"""
|
||||||
if self._driver:
|
if self._driver:
|
||||||
await self._driver.close()
|
await self._driver.close()
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
|
"""索引完成回调方法"""
|
||||||
print("KG successfully indexed.")
|
print("KG successfully indexed.")
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
|
"""检查节点是否存在
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: 节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 节点是否存在
|
||||||
|
"""
|
||||||
|
# 清理节点ID中的引号
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
|
# 构建Cypher查询
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||||
)
|
)
|
||||||
|
# 执行查询
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
single_result = await result.single()
|
single_result = await result.single()
|
||||||
|
# 记录调试日志
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
|
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
|
||||||
)
|
)
|
||||||
return single_result["node_exists"]
|
return single_result["node_exists"]
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
|
"""检查边是否存在
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_node_id: 源节点ID
|
||||||
|
target_node_id: 目标节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 边是否存在
|
||||||
|
"""
|
||||||
|
# 清理节点ID中的引号
|
||||||
entity_name_label_source = source_node_id.strip('"')
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
|
# 构建Cypher查询
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
||||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||||
)
|
)
|
||||||
|
# 执行查询
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
single_result = await result.single()
|
single_result = await result.single()
|
||||||
|
# 记录调试日志
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
|
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
|
||||||
)
|
)
|
||||||
return single_result["edgeExists"]
|
return single_result["edgeExists"]
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
"""同步关闭方法(注意:这是一个缩进错误,应该与其他方法对齐)"""
|
||||||
self._driver.close()
|
self._driver.close()
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
|
"""获取节点信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: 节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 节点属性字典,如果节点不存在则返回None
|
||||||
|
"""
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
|
# 清理节点ID中的引号
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
|
# 构建Cypher查询
|
||||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||||
|
# 执行查询
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
record = await result.single()
|
record = await result.single()
|
||||||
if record:
|
if record:
|
||||||
|
# 提取节点数据并转换为字典
|
||||||
node = record["n"]
|
node = record["n"]
|
||||||
node_dict = dict(node)
|
node_dict = dict(node)
|
||||||
|
# 记录调试日志
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
|
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
|
||||||
)
|
)
|
||||||
@ -105,9 +183,19 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
|
"""获取节点的度(与节点相连的边的数量)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: 节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 节点的度,如果节点不存在则返回None
|
||||||
|
"""
|
||||||
|
# 清理节点ID中的引号
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
|
# 构建Cypher查询,计算节点的总边数
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:`{entity_name_label}`)
|
MATCH (n:`{entity_name_label}`)
|
||||||
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
||||||
@ -116,6 +204,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
record = await result.single()
|
record = await result.single()
|
||||||
if record:
|
if record:
|
||||||
edge_count = record["totalEdgeCount"]
|
edge_count = record["totalEdgeCount"]
|
||||||
|
# 记录调试日志
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
|
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
|
||||||
)
|
)
|
||||||
@ -124,15 +213,28 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
|
"""计算边的度(源节点和目标节点的度之和)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_id: 源节点ID
|
||||||
|
tgt_id: 目标节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 边的度(两个节点的度之和)
|
||||||
|
"""
|
||||||
|
# 清理节点ID中的引号
|
||||||
entity_name_label_source = src_id.strip('"')
|
entity_name_label_source = src_id.strip('"')
|
||||||
entity_name_label_target = tgt_id.strip('"')
|
entity_name_label_target = tgt_id.strip('"')
|
||||||
|
|
||||||
|
# 获取源节点和目标节点的度
|
||||||
src_degree = await self.node_degree(entity_name_label_source)
|
src_degree = await self.node_degree(entity_name_label_source)
|
||||||
trg_degree = await self.node_degree(entity_name_label_target)
|
trg_degree = await self.node_degree(entity_name_label_target)
|
||||||
|
|
||||||
# Convert None to 0 for addition
|
# 将None转换为0以进行加法运算
|
||||||
src_degree = 0 if src_degree is None else src_degree
|
src_degree = 0 if src_degree is None else src_degree
|
||||||
trg_degree = 0 if trg_degree is None else trg_degree
|
trg_degree = 0 if trg_degree is None else trg_degree
|
||||||
|
|
||||||
|
# 计算总度数并记录日志
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
|
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
|
||||||
@ -142,19 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
|
"""获取两个节点之间的边的属性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_node_id: 源节点ID
|
||||||
|
target_node_id: 目标节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 边的属性字典,如果边不存在则返回None
|
||||||
|
"""
|
||||||
|
# 清理节点ID中的引号
|
||||||
entity_name_label_source = source_node_id.strip('"')
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
"""
|
|
||||||
Find all edges between nodes of two given labels
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_node_label (str): Label of the source nodes
|
|
||||||
target_node_label (str): Label of the target nodes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of all relationships/edges found
|
|
||||||
"""
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
|
# 构建Cypher查询,获取边的属性
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||||
RETURN properties(r) as edge_properties
|
RETURN properties(r) as edge_properties
|
||||||
@ -164,9 +268,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
entity_name_label_target=entity_name_label_target,
|
entity_name_label_target=entity_name_label_target,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
record = await result.single()
|
record = await result.single()
|
||||||
if record:
|
if record:
|
||||||
|
# 转换结果为字典并记录日志
|
||||||
result = dict(record["edge_properties"])
|
result = dict(record["edge_properties"])
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
||||||
@ -176,40 +282,49 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||||
|
"""获取指定节点的所有边
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_node_id: 源节点ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[str, str]]: 边列表,每个元素为(源节点标签, 目标节点标签)的元组
|
||||||
|
"""
|
||||||
node_label = source_node_id.strip('"')
|
node_label = source_node_id.strip('"')
|
||||||
|
|
||||||
"""
|
# 构建Cypher查询,获取节点及其所有关系
|
||||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
|
||||||
:return: List of dictionaries containing edge information
|
|
||||||
"""
|
|
||||||
query = f"""MATCH (n:`{node_label}`)
|
query = f"""MATCH (n:`{node_label}`)
|
||||||
OPTIONAL MATCH (n)-[r]-(connected)
|
OPTIONAL MATCH (n)-[r]-(connected)
|
||||||
RETURN n, r, connected"""
|
RETURN n, r, connected"""
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
results = await session.run(query)
|
results = await session.run(query)
|
||||||
edges = []
|
edges = []
|
||||||
|
# 异步迭代处理查询结果
|
||||||
async for record in results:
|
async for record in results:
|
||||||
source_node = record["n"]
|
source_node = record["n"]
|
||||||
connected_node = record["connected"]
|
connected_node = record["connected"]
|
||||||
|
|
||||||
|
# 获取源节点标签(取第一个标签)
|
||||||
source_label = (
|
source_label = (
|
||||||
list(source_node.labels)[0] if source_node.labels else None
|
list(source_node.labels)[0] if source_node.labels else None
|
||||||
)
|
)
|
||||||
|
# 获取目标节点标签(取第一个标签)
|
||||||
target_label = (
|
target_label = (
|
||||||
list(connected_node.labels)[0]
|
list(connected_node.labels)[0]
|
||||||
if connected_node and connected_node.labels
|
if connected_node and connected_node.labels
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 如果源节点和目标节点都有标签,则添加到边列表
|
||||||
if source_label and target_label:
|
if source_label and target_label:
|
||||||
edges.append((source_label, target_label))
|
edges.append((source_label, target_label))
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3), # 最多重试3次
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待
|
||||||
retry=retry_if_exception_type(
|
retry=retry_if_exception_type( # 指定需要重试的异常类型
|
||||||
(
|
(
|
||||||
neo4jExceptions.ServiceUnavailable,
|
neo4jExceptions.ServiceUnavailable,
|
||||||
neo4jExceptions.TransientError,
|
neo4jExceptions.TransientError,
|
||||||
@ -218,17 +333,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||||
"""
|
"""更新或插入节点
|
||||||
Upsert a node in the Neo4j database.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: The unique identifier for the node (used as label)
|
node_id: 节点的唯一标识符(用作标签)
|
||||||
node_data: Dictionary of node properties
|
node_data: 节点属性字典
|
||||||
"""
|
"""
|
||||||
label = node_id.strip('"')
|
label = node_id.strip('"')
|
||||||
properties = node_data
|
properties = node_data
|
||||||
|
|
||||||
async def _do_upsert(tx: AsyncManagedTransaction):
|
async def _do_upsert(tx: AsyncManagedTransaction):
|
||||||
|
"""执行节点更新/插入的内部函数"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MERGE (n:`{label}`)
|
MERGE (n:`{label}`)
|
||||||
SET n += $properties
|
SET n += $properties
|
||||||
@ -246,9 +361,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3), # 最多重试3次
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待
|
||||||
retry=retry_if_exception_type(
|
retry=retry_if_exception_type( # 指定需要重试的异常类型
|
||||||
(
|
(
|
||||||
neo4jExceptions.ServiceUnavailable,
|
neo4jExceptions.ServiceUnavailable,
|
||||||
neo4jExceptions.TransientError,
|
neo4jExceptions.TransientError,
|
||||||
@ -259,19 +374,19 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||||
):
|
):
|
||||||
"""
|
"""更新或插入边及其属性
|
||||||
Upsert an edge and its properties between two nodes identified by their labels.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_node_id (str): Label of the source node (used as identifier)
|
source_node_id: 源节点标签(用作标识符)
|
||||||
target_node_id (str): Label of the target node (used as identifier)
|
target_node_id: 目标节点标签(用作标识符)
|
||||||
edge_data (dict): Dictionary of properties to set on the edge
|
edge_data: 边属性字典
|
||||||
"""
|
"""
|
||||||
source_node_label = source_node_id.strip('"')
|
source_node_label = source_node_id.strip('"')
|
||||||
target_node_label = target_node_id.strip('"')
|
target_node_label = target_node_id.strip('"')
|
||||||
edge_properties = edge_data
|
edge_properties = edge_data
|
||||||
|
|
||||||
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
||||||
|
"""执行边更新/插入的内部函数"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (source:`{source_node_label}`)
|
MATCH (source:`{source_node_label}`)
|
||||||
WITH source
|
WITH source
|
||||||
@ -293,4 +408,5 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _node2vec_embed(self):
|
async def _node2vec_embed(self):
|
||||||
|
"""节点嵌入方法(未实际使用)"""
|
||||||
print("Implemented but never called.")
|
print("Implemented but never called.")
|
||||||
|
@ -3,12 +3,30 @@ LightRAG - 轻量级检索增强生成系统
|
|||||||
该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建
|
该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 导入异步IO模块,用于处理异步编程
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
# 导入操作系统接口模块,用于处理文件路径和环境变量
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict, dataclass, field
|
|
||||||
|
# 从dataclasses模块导入数据类相关工具
|
||||||
|
from dataclasses import (
|
||||||
|
asdict, # 将数据类实例转换为字典的函数
|
||||||
|
dataclass, # 数据类装饰器
|
||||||
|
field, # 用于定义数据类字段的函数
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入日期时间处理模块
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 从functools导入partial函数,用于创建偏函数
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Type, cast
|
|
||||||
|
# 从typing模块导入类型提示工具
|
||||||
|
from typing import (
|
||||||
|
Type, # 用于类型注解中表示类型的类型
|
||||||
|
cast, # 用于类型转换的函数
|
||||||
|
)
|
||||||
|
|
||||||
# 导入LLM相关功能
|
# 导入LLM相关功能
|
||||||
from .llm import (
|
from .llm import (
|
||||||
|
@ -1,35 +1,51 @@
|
|||||||
import os
|
# 标准库导入
|
||||||
import copy
|
import os # 操作系统接口
|
||||||
from functools import lru_cache
|
import copy # 深浅拷贝功能
|
||||||
import json
|
from functools import lru_cache # 最近最少使用缓存装饰器
|
||||||
import aioboto3
|
import json # JSON数据处理
|
||||||
import aiohttp
|
import base64 # Base64编解码
|
||||||
import numpy as np
|
import struct # 处理二进制数据结构
|
||||||
import ollama
|
|
||||||
|
|
||||||
|
# 第三方异步库
|
||||||
|
import aioboto3 # AWS SDK的异步版本
|
||||||
|
import aiohttp # 异步HTTP客户端/服务器
|
||||||
|
import ollama # Ollama API客户端
|
||||||
|
|
||||||
|
# 数值计算和机器学习库
|
||||||
|
import numpy as np # 数值计算库
|
||||||
|
import torch # PyTorch深度学习框架
|
||||||
|
from transformers import ( # Hugging Face转换器库
|
||||||
|
AutoTokenizer, # 自动分词器
|
||||||
|
AutoModelForCausalLM, # 自动因果语言模型
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI相关导入
|
||||||
from openai import (
|
from openai import (
|
||||||
AsyncOpenAI,
|
AsyncOpenAI, # OpenAI异步客户端
|
||||||
APIConnectionError,
|
APIConnectionError, # API连接错误
|
||||||
RateLimitError,
|
RateLimitError, # 速率限制错误
|
||||||
Timeout,
|
Timeout, # 超时错误
|
||||||
AsyncAzureOpenAI,
|
AsyncAzureOpenAI, # Azure OpenAI异步客户端
|
||||||
)
|
)
|
||||||
|
|
||||||
import base64
|
# 重试机制相关导入
|
||||||
import struct
|
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry, # 重试装饰器
|
||||||
stop_after_attempt,
|
stop_after_attempt, # 最大重试次数
|
||||||
wait_exponential,
|
wait_exponential, # 指数退避等待
|
||||||
retry_if_exception_type,
|
retry_if_exception_type, # 基于异常类型的重试条件
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据验证和类型提示
|
||||||
|
from pydantic import BaseModel, Field # 数据验证模型
|
||||||
|
from typing import List, Dict, Callable, Any # 类型提示
|
||||||
|
|
||||||
|
# 本地模块导入
|
||||||
|
from .base import BaseKVStorage # 键值存储基类
|
||||||
|
from .utils import (
|
||||||
|
compute_args_hash, # 计算参数哈希值
|
||||||
|
wrap_embedding_func_with_attrs, # 包装嵌入函数的装饰器
|
||||||
)
|
)
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
||||||
import torch
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import List, Dict, Callable, Any
|
|
||||||
from .base import BaseKVStorage
|
|
||||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
|
||||||
|
|
||||||
# 禁用并行化以避免tokenizers的并行化导致的问题
|
# 禁用并行化以避免tokenizers的并行化导致的问题
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
@ -1,30 +1,43 @@
|
|||||||
import asyncio
|
# 标准库导入
|
||||||
import json
|
import asyncio # 异步IO支持
|
||||||
import re
|
import json # JSON数据处理
|
||||||
from typing import Union
|
import re # 正则表达式支持
|
||||||
from collections import Counter, defaultdict
|
from typing import Union # 类型提示:联合类型
|
||||||
import warnings
|
from collections import (
|
||||||
|
Counter, # 计数器集合类
|
||||||
|
defaultdict # 带默认值的字典
|
||||||
|
)
|
||||||
|
import warnings # 警告控制
|
||||||
|
|
||||||
|
# 从本地utils模块导入工具函数
|
||||||
from .utils import (
|
from .utils import (
|
||||||
logger,
|
logger, # 日志记录器
|
||||||
clean_str,
|
clean_str, # 字符串清理函数
|
||||||
compute_mdhash_id,
|
compute_mdhash_id, # 计算MD5哈希ID
|
||||||
decode_tokens_by_tiktoken,
|
decode_tokens_by_tiktoken, # tiktoken解码函数
|
||||||
encode_string_by_tiktoken,
|
encode_string_by_tiktoken, # tiktoken编码函数
|
||||||
is_float_regex,
|
is_float_regex, # 浮点数检查函数
|
||||||
list_of_list_to_csv,
|
list_of_list_to_csv, # 列表转CSV函数
|
||||||
pack_user_ass_to_openai_messages,
|
pack_user_ass_to_openai_messages, # OpenAI消息打包函数
|
||||||
split_string_by_multi_markers,
|
split_string_by_multi_markers, # 多标记字符串分割函数
|
||||||
truncate_list_by_token_size,
|
truncate_list_by_token_size, # 基于token大小截断列表
|
||||||
process_combine_contexts,
|
process_combine_contexts, # 上下文合并处理函数
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 从本地base模块导入基础类
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage, # 图存储基类
|
||||||
BaseKVStorage,
|
BaseKVStorage, # 键值存储基类
|
||||||
BaseVectorStorage,
|
BaseVectorStorage, # 向量存储基类
|
||||||
TextChunkSchema,
|
TextChunkSchema, # 文本块模式定义
|
||||||
QueryParam,
|
QueryParam, # 查询参数类
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从本地prompt模块导入提示相关常量
|
||||||
|
from .prompt import (
|
||||||
|
GRAPH_FIELD_SEP, # 图字段分隔符
|
||||||
|
PROMPTS # 提示模板集合
|
||||||
)
|
)
|
||||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
|
||||||
|
|
||||||
|
|
||||||
def chunking_by_token_size(
|
def chunking_by_token_size(
|
||||||
|
@ -1,23 +1,34 @@
|
|||||||
import asyncio
|
# 标准库导入
|
||||||
import html
|
import asyncio # 异步IO支持
|
||||||
import os
|
import html # HTML实体编解码
|
||||||
from dataclasses import dataclass
|
import os # 操作系统接口,用于文件和路径操作
|
||||||
from typing import Any, Union, cast
|
|
||||||
import networkx as nx
|
|
||||||
import numpy as np
|
|
||||||
from nano_vectordb import NanoVectorDB
|
|
||||||
|
|
||||||
from .utils import (
|
# 数据类和类型提示相关导入
|
||||||
logger,
|
from dataclasses import dataclass # 数据类装饰器
|
||||||
load_json,
|
from typing import (
|
||||||
write_json,
|
Any, # 任意类型
|
||||||
compute_mdhash_id,
|
Union, # 联合类型
|
||||||
|
cast # 类型转换函数
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 第三方库导入
|
||||||
|
import networkx as nx # 图数据处理库
|
||||||
|
import numpy as np # 数值计算库
|
||||||
|
from nano_vectordb import NanoVectorDB # 向量数据库
|
||||||
|
|
||||||
|
# 从本地utils模块导入工具函数
|
||||||
|
from .utils import (
|
||||||
|
logger, # 日志记录器
|
||||||
|
load_json, # JSON文件加载函数
|
||||||
|
write_json, # JSON文件写入函数
|
||||||
|
compute_mdhash_id, # 计算MD5哈希ID的函数
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从本地base模块导入基础存储类
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage, # 图存储基类
|
||||||
BaseKVStorage,
|
BaseKVStorage, # 键值存储基类
|
||||||
BaseVectorStorage,
|
BaseVectorStorage, # 向量存储基类
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -1,19 +1,29 @@
|
|||||||
import asyncio
|
# 标准库导入 - 异步和IO操作
|
||||||
import html
|
import asyncio # 异步IO支持
|
||||||
import io
|
import html # HTML实体编解码
|
||||||
import csv
|
import io # 内存IO操作
|
||||||
import json
|
import csv # CSV文件处理
|
||||||
import logging
|
import json # JSON数据处理
|
||||||
import os
|
import logging # 日志记录
|
||||||
import re
|
import os # 操作系统接口
|
||||||
from dataclasses import dataclass
|
import re # 正则表达式
|
||||||
from functools import wraps
|
|
||||||
from hashlib import md5
|
|
||||||
from typing import Any, Union, List
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
|
|
||||||
import numpy as np
|
# 标准库导入 - 数据结构和工具
|
||||||
import tiktoken
|
from dataclasses import dataclass # 数据类装饰器
|
||||||
|
from functools import wraps # 装饰器工具
|
||||||
|
from hashlib import md5 # MD5哈希算法
|
||||||
|
from typing import ( # 类型提示
|
||||||
|
Any, # 任意类型
|
||||||
|
Union, # 联合类型
|
||||||
|
List # 列表类型
|
||||||
|
)
|
||||||
|
|
||||||
|
# XML处理
|
||||||
|
import xml.etree.ElementTree as ET # XML解析和处理
|
||||||
|
|
||||||
|
# 第三方库导入
|
||||||
|
import numpy as np # 数值计算库
|
||||||
|
import tiktoken # OpenAI的分词器
|
||||||
|
|
||||||
# 全局编码器变量
|
# 全局编码器变量
|
||||||
ENCODER = None
|
ENCODER = None
|
||||||
|
Loading…
Reference in New Issue
Block a user