# Standard library imports
import asyncio  # async I/O support
import os  # environment-variable access
import inspect  # used to put the current function name into debug logs

# Dataclass and typing helpers
from dataclasses import dataclass
from typing import (
    Any,
    Union,
    Tuple,
    List,
    Dict,
)

# Local imports
from lightrag.utils import logger
from ..base import BaseGraphStorage

# Neo4j driver
from neo4j import (
    AsyncGraphDatabase,
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
)

# Retry helpers
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)


@dataclass
class Neo4JStorage(BaseGraphStorage):
    """Graph storage backend that keeps nodes and edges in Neo4j.

    Each node is addressed by a *label* derived from its node id (the id with
    surrounding double quotes stripped). Edges are written with the
    relationship type ``DIRECTED``.
    """

    @staticmethod
    def load_nx_graph(file_name):
        """Placeholder for NetworkX preloading; only prints a notice.

        Args:
            file_name: Graph file name (ignored).
        """
        print("no preloading of graph with neo4j in production")

    @staticmethod
    def _label(node_id: str) -> str:
        """Convert a node id into text that is safe inside a backtick-quoted label.

        Surrounding double quotes are stripped (as every query in this class
        expects), and embedded backticks are doubled, which is how Cypher
        escapes a backtick inside a quoted identifier. Labels cannot be passed
        as query parameters, so escaping is the only way to keep an id from
        terminating the identifier and injecting Cypher.

        Args:
            node_id: Raw node id, possibly wrapped in double quotes.

        Returns:
            The label text to place between backticks in a query.
        """
        # NOTE(review): servers that pre-process ``\u0060`` escapes in query
        # text may need that sequence handled as well - confirm against the
        # deployed Neo4j version.
        return node_id.strip('"').replace("`", "``")

    def __init__(self, namespace, global_config):
        """Create the storage and open an async Neo4j driver.

        Connection settings are read from the ``NEO4J_URI``,
        ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` environment variables.

        Args:
            namespace: Storage namespace, forwarded to the base class.
            global_config: Global configuration, forwarded to the base class.

        Raises:
            KeyError: If one of the required environment variables is unset.
        """
        super().__init__(namespace=namespace, global_config=global_config)

        self._driver_lock = asyncio.Lock()

        uri = os.environ["NEO4J_URI"]
        username = os.environ["NEO4J_USERNAME"]
        password = os.environ["NEO4J_PASSWORD"]

        # Becomes None again once close() has run.
        self._driver: Union[AsyncDriver, None] = AsyncGraphDatabase.driver(
            uri, auth=(username, password)
        )

    def __post_init__(self):
        """Register the available node-embedding algorithms by name."""
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def close(self):
        """Close the Neo4j driver; calling it again is a no-op."""
        if self._driver:
            await self._driver.close()
            self._driver = None

    async def __aenter__(self):
        """Enter ``async with``; returns the storage itself."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Leave ``async with`` by closing the driver."""
        await self.close()

    async def index_done_callback(self):
        """Hook invoked after indexing; only prints a notice."""
        print("KG successfully indexed.")

    async def has_node(self, node_id: str) -> bool:
        """Report whether at least one node carries the label for ``node_id``.

        Args:
            node_id: Node id (surrounding double quotes are ignored).

        Returns:
            True if a matching node exists, otherwise False.
        """
        label = self._label(node_id)
        query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"

        async with self._driver.session() as session:
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
            )
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Report whether any relationship connects the two nodes.

        The match ignores relationship direction.

        Args:
            source_node_id: Id of one endpoint.
            target_node_id: Id of the other endpoint.

        Returns:
            True if at least one relationship exists, otherwise False.
        """
        source_label = self._label(source_node_id)
        target_label = self._label(target_node_id)
        query = (
            f"MATCH (a:`{source_label}`)-[r]-(b:`{target_label}`) "
            "RETURN COUNT(r) > 0 AS edgeExists"
        )

        async with self._driver.session() as session:
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
            )
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Fetch the properties of the node labelled by ``node_id``.

        Args:
            node_id: Node id (surrounding double quotes are ignored).

        Returns:
            The node's properties as a dict, or None if no node matches.
        """
        label = self._label(node_id)
        query = f"MATCH (n:`{label}`) RETURN n"

        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if not record:
                return None

            node_dict = dict(record["n"])
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
            )
            return node_dict

    async def node_degree(self, node_id: str) -> Union[int, None]:
        """Count the relationships attached to a node, in either direction.

        Args:
            node_id: Node id (surrounding double quotes are ignored).

        Returns:
            The number of attached relationships, or None if no node matches.
        """
        label = self._label(node_id)
        query = f"""
            MATCH (n:`{label}`)
            RETURN COUNT{{ (n)--() }} AS totalEdgeCount
        """

        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if not record:
                return None

            edge_count = record["totalEdgeCount"]
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
            )
            return edge_count

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of the degrees of an edge's two endpoints.

        Args:
            src_id: Source node id.
            tgt_id: Target node id.

        Returns:
            Degree of the source plus degree of the target; a missing node
            contributes 0.
        """
        # node_degree() normalizes the ids itself, so pass them through raw.
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # node_degree() yields None for an unknown node; count that as 0.
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        logger.debug(
            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
        )
        return degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Fetch the properties of one relationship from source to target.

        Args:
            source_node_id: Id of the start node.
            target_node_id: Id of the end node.

        Returns:
            The relationship's properties as a dict, or None if there is none.
        """
        source_label = self._label(source_node_id)
        target_label = self._label(target_node_id)

        # NOTE(review): this match is directed while has_edge() is undirected,
        # so has_edge(a, b) can be True while get_edge(a, b) is None - confirm
        # which behaviour callers expect.
        # The f-string is already fully interpolated; it must not be passed
        # through str.format(), which would choke on braces in a label.
        query = f"""
            MATCH (start:`{source_label}`)-[r]->(end:`{target_label}`)
            RETURN properties(r) as edge_properties
            LIMIT 1
        """

        async with self._driver.session() as session:
            result = await session.run(query)
            record = await result.single()
            if not record:
                return None

            edge_properties = dict(record["edge_properties"])
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_properties}"
            )
            return edge_properties

    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
        """List the edges touching a node as pairs of labels.

        Args:
            source_node_id: Id of the node whose edges are wanted.

        Returns:
            A list of ``(source_label, connected_label)`` tuples. Each side is
            the first label reported for that node; rows where either side
            has no label (including a node with no relationships) are skipped.
        """
        node_label = self._label(source_node_id)
        query = f"""MATCH (n:`{node_label}`)
                OPTIONAL MATCH (n)-[r]-(connected)
                RETURN n, r, connected"""

        async with self._driver.session() as session:
            results = await session.run(query)
            edges = []
            async for record in results:
                source_node = record["n"]
                # None when OPTIONAL MATCH found no relationship.
                connected_node = record["connected"]

                source_label = (
                    list(source_node.labels)[0] if source_node.labels else None
                )
                target_label = (
                    list(connected_node.labels)[0]
                    if connected_node and connected_node.labels
                    else None
                )

                if source_label and target_label:
                    edges.append((source_label, target_label))

            return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
        """Create the node for ``node_id`` if needed and merge in its properties.

        Retried up to three attempts with exponential back-off on the Neo4j
        availability/transient errors listed in the decorator.

        Args:
            node_id: Node id, used as the node's label.
            node_data: Properties to merge into the node.

        Raises:
            Exception: Any error from the driver is logged and re-raised.
        """
        label = self._label(node_id)
        properties = node_data

        async def _do_upsert(tx: AsyncManagedTransaction):
            """Run the MERGE/SET statement inside a managed write transaction."""
            query = f"""
            MERGE (n:`{label}`)
            SET n += $properties
            """
            await tx.run(query, properties=properties)
            logger.debug(
                f"Upserted node with label '{label}' and properties: {properties}"
            )

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
    ):
        """Create or update the ``DIRECTED`` relationship between two nodes.

        Both endpoints are looked up with MATCH, so nothing is written when
        either node is absent. Retried up to three attempts with exponential
        back-off on the Neo4j errors listed in the decorator.

        Args:
            source_node_id: Id of the start node.
            target_node_id: Id of the end node.
            edge_data: Properties to merge into the relationship.

        Raises:
            Exception: Any error from the driver is logged and re-raised.
        """
        source_node_label = self._label(source_node_id)
        target_node_label = self._label(target_node_id)
        edge_properties = edge_data

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
            """Run the MATCH/MERGE statement inside a managed write transaction."""
            query = f"""
            MATCH (source:`{source_node_label}`)
            WITH source
            MATCH (target:`{target_node_label}`)
            MERGE (source)-[r:DIRECTED]->(target)
            SET r += $properties
            RETURN r
            """
            await tx.run(query, properties=edge_properties)
            logger.debug(
                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
            )

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert_edge)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

    async def _node2vec_embed(self):
        """Placeholder node2vec embedding; only prints a notice."""
        print("Implemented but never called.")