lightrag-comments/lightrag/kg/neo4j_impl.py

413 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 标准库导入
import asyncio # 异步IO支持
import os # 操作系统接口,用于环境变量访问
import inspect # 用于运行时检查Python对象
# 数据类和类型提示相关导入
from dataclasses import dataclass # 数据类装饰器
from typing import (
Any, # 任意类型
Union, # 联合类型
Tuple, # 元组类型
List, # 列表类型
Dict # 字典类型
)
# 本地模块导入
from lightrag.utils import logger # 日志记录器
from ..base import BaseGraphStorage # 图存储基类
# Neo4j相关导入
from neo4j import (
AsyncGraphDatabase, # Neo4j异步图数据库驱动
exceptions as neo4jExceptions, # Neo4j异常类
AsyncDriver, # Neo4j异步驱动接口
AsyncManagedTransaction, # Neo4j异步事务管理
)
# 重试机制相关导入
from tenacity import (
retry, # 重试装饰器
stop_after_attempt, # 最大重试次数限制
wait_exponential, # 指数退避等待策略
retry_if_exception_type, # 基于异常类型的重试条件
)
@dataclass
class Neo4JStorage(BaseGraphStorage):
    """Graph storage backend that keeps the knowledge graph in Neo4j.

    Each node is stored with a single label equal to its id (surrounding
    double quotes stripped); edges written by this class use the
    ``DIRECTED`` relationship type. Connection settings come from the
    ``NEO4J_URI``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` environment
    variables.
    """

    @staticmethod
    def load_nx_graph(file_name):
        """No-op: a Neo4j-backed graph is never preloaded from a file.

        Args:
            file_name: Ignored.
        """
        print("no preloading of graph with neo4j in production")

    def __init__(self, namespace, global_config):
        """Create the storage and its asynchronous Neo4j driver.

        Args:
            namespace: Storage namespace, forwarded to the base class.
            global_config: Global configuration, forwarded to the base class.

        Raises:
            KeyError: If NEO4J_URI, NEO4J_USERNAME or NEO4J_PASSWORD is unset.
        """
        super().__init__(namespace=namespace, global_config=global_config)
        # Async lock kept as an attribute; no method in this class acquires it.
        self._driver_lock = asyncio.Lock()
        uri = os.environ["NEO4J_URI"]
        username = os.environ["NEO4J_USERNAME"]
        password = os.environ["NEO4J_PASSWORD"]
        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
            uri, auth=(username, password)
        )
        # Because __init__ is hand-written, @dataclass does not generate one,
        # and only the generated __init__ would call __post_init__. Call it
        # explicitly so _node_embed_algorithms always exists; it merely
        # assigns a dict, so an extra call from a base __init__ is harmless.
        self.__post_init__()

    def __post_init__(self):
        """Register the available node-embedding algorithms by name."""
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    @staticmethod
    def _escape_label(node_id: str) -> str:
        """Convert a node id into text safe inside a backtick-quoted label.

        Surrounding double quotes are stripped (the convention every method
        here follows), and embedded backticks are doubled, which is Cypher's
        escape for a backtick inside a quoted identifier. Without this an id
        containing a backtick would close the quoted label early and break
        or alter the query.
        """
        return node_id.strip('"').replace("`", "``")

    async def close(self):
        """Close the Neo4j driver if it is open; safe to call repeatedly."""
        if self._driver:
            await self._driver.close()
            self._driver = None

    async def __aenter__(self):
        """Support ``async with``; the driver is already created in __init__."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the driver when leaving an ``async with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        await self.close()

    async def index_done_callback(self):
        """Hook called after indexing; this backend only prints a message."""
        print("KG successfully indexed.")

    async def has_node(self, node_id: str) -> bool:
        """Return True if at least one node carries the label for ``node_id``.

        Args:
            node_id: Node id; surrounding double quotes are ignored.
        """
        entity_name_label = self._escape_label(node_id)
        query = (
            f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
        )
        async with self._driver.session() as session:
            result = await session.run(query)
            # An aggregate query always yields exactly one row.
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
            )
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if any relationship links the two nodes.

        NOTE(review): the pattern ``-[r]-`` matches either direction, while
        get_edge only follows source->target, so has_edge(a, b) can be True
        while get_edge(a, b) returns None. Left as-is because callers may
        depend on either behaviour.

        Args:
            source_node_id: Id of one endpoint.
            target_node_id: Id of the other endpoint.
        """
        entity_name_label_source = self._escape_label(source_node_id)
        entity_name_label_target = self._escape_label(target_node_id)
        query = (
            f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
            "RETURN COUNT(r) > 0 AS edgeExists"
        )
        async with self._driver.session() as session:
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
            )
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the properties of the node for ``node_id``, or None.

        Args:
            node_id: Node id; surrounding double quotes are ignored.

        Returns:
            A dict of the node's properties, or None if no node matches.
        """
        entity_name_label = self._escape_label(node_id)
        query = f"MATCH (n:`{entity_name_label}`) RETURN n"
        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if not record:
                return None
            node_dict = dict(record["n"])
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
            )
            return node_dict

    async def node_degree(self, node_id: str) -> Union[int, None]:
        """Return the number of relationships attached to the node.

        Args:
            node_id: Node id; surrounding double quotes are ignored.

        Returns:
            The relationship count, or None if no node matches.
        """
        entity_name_label = self._escape_label(node_id)
        query = f"""
            MATCH (n:`{entity_name_label}`)
            RETURN COUNT{{ (n)--() }} AS totalEdgeCount
        """
        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if not record:
                return None
            edge_count = record["totalEdgeCount"]
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
            )
            return edge_count

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of the degrees of the edge's two endpoints.

        A missing endpoint contributes 0.

        Args:
            src_id: Id of the source node.
            tgt_id: Id of the target node.
        """
        # node_degree strips quotes and escapes the label itself.
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree
        degrees = int(src_degree) + int(trg_degree)
        logger.debug(
            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
        )
        return degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the properties of one source->target relationship, or None.

        Args:
            source_node_id: Id of the start node.
            target_node_id: Id of the end node.

        Returns:
            A dict of the first matching relationship's properties, or None.
        """
        entity_name_label_source = self._escape_label(source_node_id)
        entity_name_label_target = self._escape_label(target_node_id)
        # Plain f-string only: the previous extra .format() call re-parsed the
        # interpolated labels and raised on any name containing '{' or '}'.
        query = f"""
            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
            RETURN properties(r) as edge_properties
            LIMIT 1
        """
        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if not record:
                return None
            edge_props = dict(record["edge_properties"])
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_props}"
            )
            return edge_props

    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
        """List (source_label, neighbour_label) pairs for every relationship.

        Relationships in both directions are included. Only the first label
        of each node is used; pairs where either side has no label (including
        the no-neighbour row from OPTIONAL MATCH) are skipped.

        Args:
            source_node_id: Node id; surrounding double quotes are ignored.

        Returns:
            A possibly empty list of label pairs.
        """
        node_label = self._escape_label(source_node_id)
        query = f"""MATCH (n:`{node_label}`)
                OPTIONAL MATCH (n)-[r]-(connected)
                RETURN n, r, connected"""
        edges = []
        async with self._driver.session() as session:
            results = await session.run(query)
            async for record in results:
                source_node = record["n"]
                connected_node = record["connected"]
                source_label = (
                    list(source_node.labels)[0] if source_node.labels else None
                )
                target_label = (
                    list(connected_node.labels)[0]
                    if connected_node and connected_node.labels
                    else None
                )
                if source_label and target_label:
                    edges.append((source_label, target_label))
        return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
        """Create the node for ``node_id`` if absent and merge in properties.

        ``SET n += $properties`` overwrites the given keys and keeps others.
        Up to 3 attempts with exponential backoff are made on
        ServiceUnavailable, TransientError and WriteServiceUnavailable.

        Args:
            node_id: Node id, used as the node's label.
            node_data: Properties to merge into the node.

        Raises:
            Exception: Any driver error is logged and re-raised.
        """
        label = self._escape_label(node_id)
        properties = node_data

        async def _do_upsert(tx: AsyncManagedTransaction):
            """Run the MERGE/SET inside a managed write transaction."""
            query = f"""
            MERGE (n:`{label}`)
            SET n += $properties
            """
            # Properties travel as a query parameter, never as Cypher text.
            await tx.run(query, properties=properties)
            logger.debug(
                f"Upserted node with label '{label}' and properties: {properties}"
            )

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
    ):
        """Create or update a DIRECTED relationship from source to target.

        Nothing is written if either endpoint node does not exist, because
        both are MATCHed rather than MERGEd. Same retry policy as upsert_node.

        Args:
            source_node_id: Id (label) of the start node.
            target_node_id: Id (label) of the end node.
            edge_data: Properties to merge into the relationship.

        Raises:
            Exception: Any driver error is logged and re-raised.
        """
        source_node_label = self._escape_label(source_node_id)
        target_node_label = self._escape_label(target_node_id)
        edge_properties = edge_data

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
            """Run the MATCH/MERGE/SET inside a managed write transaction."""
            query = f"""
            MATCH (source:`{source_node_label}`)
            WITH source
            MATCH (target:`{target_node_label}`)
            MERGE (source)-[r:DIRECTED]->(target)
            SET r += $properties
            RETURN r
            """
            await tx.run(query, properties=edge_properties)
            logger.debug(
                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
            )

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert_edge)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

    async def _node2vec_embed(self):
        """Placeholder node2vec embedding; registered but not implemented."""
        print("Implemented but never called.")