diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8b29974..e7ccec2 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1,3 +1,8 @@ +""" +LightRAG - 轻量级检索增强生成系统 +该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建 +""" + import asyncio import os from dataclasses import asdict, dataclass, field @@ -5,144 +10,170 @@ from datetime import datetime from functools import partial from typing import Type, cast +# 导入LLM相关功能 from .llm import ( - gpt_4o_mini_complete, - openai_embedding, + gpt_4o_mini_complete, # GPT模型完成功能 + openai_embedding, # OpenAI文本嵌入功能 ) + +# 导入核心操作功能 from .operate import ( - chunking_by_token_size, - extract_entities, - local_query, - global_query, - hybrid_query, - naive_query, + chunking_by_token_size, # 文本分块 + extract_entities, # 实体提取 + local_query, # 本地查询 + global_query, # 全局查询 + hybrid_query, # 混合查询 + naive_query, # 简单查询 ) +# 导入存储实现 from .storage import ( - JsonKVStorage, - NanoVectorDBStorage, - NetworkXStorage, + JsonKVStorage, # JSON键值存储 + NanoVectorDBStorage, # 向量数据库存储 + NetworkXStorage, # 图数据库存储 ) -from .kg.neo4j_impl import Neo4JStorage -# future KG integrations - +from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现 +# 未来可能的图数据库集成 # from .kg.ArangoDB_impl import ( # GraphStorage as ArangoDBStorage # ) - +# 导入工具函数 from .utils import ( - EmbeddingFunc, - compute_mdhash_id, - limit_async_func_call, - convert_response_to_json, - logger, - set_logger, + EmbeddingFunc, # 嵌入函数类型 + compute_mdhash_id, # 计算MD5哈希ID + limit_async_func_call, # 限制异步函数调用 + convert_response_to_json, # 响应转JSON + logger, # 日志记录器 + set_logger, # 设置日志 ) + +# 导入基类 from .base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, - StorageNameSpace, - QueryParam, + BaseGraphStorage, # 图存储基类 + BaseKVStorage, # 键值存储基类 + BaseVectorStorage, # 向量存储基类 + StorageNameSpace, # 存储命名空间 + QueryParam, # 查询参数 ) def always_get_an_event_loop() -> asyncio.AbstractEventLoop: + """ + 获取或创建事件循环 + 如果当前线程没有事件循环,则创建一个新的 + 返回值: + asyncio.AbstractEventLoop: 事件循环实例 + """ try: return asyncio.get_event_loop() - except RuntimeError: logger.info("Creating a new event loop in main thread.") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - return loop @dataclass class LightRAG: + """ + 轻量级检索增强生成(LightRAG)系统的主类 + 实现了文档的存储、检索、知识图谱构建和问答功能 + """ + + # 工作目录配置,用存储所有缓存文件 working_dir: str = field( default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) + # 知识图谱存储类型,默认使用NetworkX实现 kg: str = field(default="NetworkXStorage") + # 日志级别设置 current_log_level = logger.level log_level: str = field(default=current_log_level) - # text chunking - chunk_token_size: int = 1200 - chunk_overlap_token_size: int = 100 - tiktoken_model_name: str = "gpt-4o-mini" + # 文本分块参数配置 + chunk_token_size: int = 1200 # 每个文本块的目标token数 + chunk_overlap_token_size: int = 100 # 相邻文本块的重叠token数 + tiktoken_model_name: str = "gpt-4o-mini" # 用于计算token的模型名称 - # entity extraction - entity_extract_max_gleaning: int = 1 - entity_summary_to_max_tokens: int = 500 + # 实体提取参数 + entity_extract_max_gleaning: int = 1 # 最大实体提取次数 + entity_summary_to_max_tokens: int = 500 # 实体摘要的最大token数 - # node embedding - node_embedding_algorithm: str = "node2vec" + # 节点嵌入配置 + node_embedding_algorithm: str = "node2vec" # 节点嵌入算法选择 node2vec_params: dict = field( default_factory=lambda: { - "dimensions": 1536, - "num_walks": 10, - "walk_length": 40, - "window_size": 2, - "iterations": 3, - "random_seed": 3, + "dimensions": 1536, # 嵌入向量维度 + "num_walks": 10, # 每个节点的随机游走次数 + "walk_length": 40, # 每次随机游走的长度 + "window_size": 2, # 上下文窗口大小 + "iterations": 3, # 训练迭代次数 + "random_seed": 3, # 随机种子 } ) - # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) - embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) - embedding_batch_num: int = 32 - embedding_func_max_async: int = 16 + # 文本嵌入配置 + embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) # 默认使用OpenAI的嵌入模型 + embedding_batch_num: int = 32 # 批处理大小 + embedding_func_max_async: int = 16 # 最大并发请求数 - # LLM - llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# - llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' - llm_model_max_token_size: int = 32768 - llm_model_max_async: int = 16 - llm_model_kwargs: dict = field(default_factory=dict) + # 语言模型配置 + llm_model_func: callable = gpt_4o_mini_complete # 默认使用的语言模型 + llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 模型名称 + llm_model_max_token_size: int = 32768 # 模型最大token限制 + llm_model_max_async: int = 16 # 最大并发请求数 + llm_model_kwargs: dict = field(default_factory=dict) # 模型额外参数 - # storage - key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage - vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage - vector_db_storage_cls_kwargs: dict = field(default_factory=dict) - enable_llm_cache: bool = True + # 存储配置 + key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage # 键值存储类 + vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage # 向量存储类 + vector_db_storage_cls_kwargs: dict = field(default_factory=dict) # 向量存储额外参数 + enable_llm_cache: bool = True # 是否启用语言模型缓存 - # extension - addon_params: dict = field(default_factory=dict) - convert_response_to_json_func: callable = convert_response_to_json + # 扩展配置 + addon_params: dict = field(default_factory=dict) # 附加参数 + convert_response_to_json_func: callable = convert_response_to_json # JSON转换函数 def __post_init__(self): + """ + 初始化方法,在对象创建后自动调用 + 负责设置日志、初始化存储系统和配置各种功能组件 + """ + # 配置日志系统 log_file = os.path.join(self.working_dir, "lightrag.log") set_logger(log_file) logger.setLevel(self.log_level) - logger.info(f"Logger initialized for working directory: {self.working_dir}") + # 记录初始化参数 _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") - # @TODO: should move all storage setup here to leverage initial start params attached to self. + # 根据配置选择图存储实现类 self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ self.kg ] + # 确保工作目录存在 if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) + # 初始化文档存储系统 self.full_docs = self.key_string_value_json_storage_cls( namespace="full_docs", global_config=asdict(self) ) + # 初始化文本块存储系统 self.text_chunks = self.key_string_value_json_storage_cls( namespace="text_chunks", global_config=asdict(self) ) + # 初始化语言模型响应缓存(如果启用) self.llm_response_cache = ( self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self) @@ -150,32 +181,40 @@ class LightRAG: if self.enable_llm_cache else None ) + + # 初始化实体关系图存储 self.chunk_entity_relation_graph = self.graph_storage_cls( namespace="chunk_entity_relation", global_config=asdict(self) ) + # 配置嵌入函数的并发限制 self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) + # 初始化向量数据库存储系统 + # 用于存储实体的向量表示 self.entities_vdb = self.vector_db_storage_cls( namespace="entities", global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) + # 用于存储关系的向量表示 self.relationships_vdb = self.vector_db_storage_cls( namespace="relationships", global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) + # 用于存储文本块的向量表示 self.chunks_vdb = self.vector_db_storage_cls( namespace="chunks", global_config=asdict(self), embedding_func=self.embedding_func, ) + # 配置语言模型函数的并发限制和缓存 self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, @@ -185,33 +224,62 @@ class LightRAG: ) def _get_storage_class(self) -> Type[BaseGraphStorage]: + """ + 获取图存储类的实现 + 根据配置选择合适的图存储后端(Neo4J或NetworkX) + + 返回值: + Type[BaseGraphStorage]: 图存储类 + """ return { "Neo4JStorage": Neo4JStorage, "NetworkXStorage": NetworkXStorage, } def insert(self, string_or_strings): + """ + 同步方式插入文档 + 将字符串或字符串列表插入到系统中进行处理 + + 参数: + string_or_strings: 单个字符串或字符串列表,表示要处理的文档内容 + """ loop = always_get_an_event_loop() return loop.run_until_complete(self.ainsert(string_or_strings)) async def ainsert(self, string_or_strings): + """ + 异步方式插入文档 + 处理文档内容,包括分块、实体提取和向量化存储 + + 参数: + string_or_strings: 单个字符串或字符串列表,表示要处理的文档内容 + """ try: + # 确保输入是列表形式 if isinstance(string_or_strings, str): string_or_strings = [string_or_strings] + # 为每个文档生成唯一ID并创建文档字典 new_docs = { compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()} for c in string_or_strings } + + # 过滤掉已存在的文档 _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + if not len(new_docs): logger.warning("All docs are already in the storage") return + logger.info(f"[New Docs] inserting {len(new_docs)} docs") + # 处理文档分块 inserting_chunks = {} for doc_key, doc in new_docs.items(): + # 对每个文档进行分块,并为每个块生成唯一ID chunks = { compute_mdhash_id(dp["content"], prefix="chunk-"): { **dp, @@ -225,19 +293,25 @@ class LightRAG: ) } inserting_chunks.update(chunks) + + # 过滤掉已存在的文本块 _add_chunk_keys = await self.text_chunks.filter_keys( list(inserting_chunks.keys()) ) inserting_chunks = { k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } + if not len(inserting_chunks): logger.warning("All chunks are already in the storage") return + logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") + # 将新的文本块插入向量数据库 await self.chunks_vdb.upsert(inserting_chunks) + # 提取实体和关系并更新知识图谱 logger.info("[Entity Extraction]...") maybe_new_kg = await extract_entities( inserting_chunks, @@ -246,38 +320,70 @@ class LightRAG: relationships_vdb=self.relationships_vdb, global_config=asdict(self), ) + if maybe_new_kg is None: logger.warning("No new entities and relationships found") return + self.chunk_entity_relation_graph = maybe_new_kg + # 更新文档和文本块存储 await self.full_docs.upsert(new_docs) await self.text_chunks.upsert(inserting_chunks) finally: + # 完成插入后执行清理工作 await self._insert_done() async def _insert_done(self): + """ + 插入操作完成后的回调函数 + 负责更新所有存储实例的索引并保存状态 + """ tasks = [] + # 遍历所有需要执行回调的存储实例 for storage_inst in [ - self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.entities_vdb, - self.relationships_vdb, - self.chunks_vdb, - self.chunk_entity_relation_graph, + self.full_docs, # 完整文档存储 + self.text_chunks, # 文本块存储 + self.llm_response_cache, # LLM响应缓存 + self.entities_vdb, # 实体向量数据库 + self.relationships_vdb, # 关系向量数据库 + self.chunks_vdb, # 文本块向量���据库 + self.chunk_entity_relation_graph, # 实体关系图 ]: if storage_inst is None: continue + # 将每个存储实例的回调任务添加到任务列表 tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) + # 并发执行所有回调任务 await asyncio.gather(*tasks) def query(self, query: str, param: QueryParam = QueryParam()): + """ + 同步方式执行查询 + + 参数: + query: 查询文本 + param: 查询参数配置 + 返回值: + 查询结果 + """ loop = always_get_an_event_loop() return loop.run_until_complete(self.aquery(query, param)) async def aquery(self, query: str, param: QueryParam = QueryParam()): + """ + 异步方式执行查询 + 支持多种查询模式:本地查询、全局查询、混合查询和简单查询 + + 参数: + query: 查询文本 + param: 查询参数配置 + 返回值: + 查询结果 + """ + # 根据查询模式选择相应的查询方法 if param.mode == "local": + # 本地查询:主要基于局部上下文 response = await local_query( query, self.chunk_entity_relation_graph, @@ -288,6 +394,7 @@ class LightRAG: asdict(self), ) elif param.mode == "global": + # 全局查询:考虑整个知识图谱 response = await global_query( query, self.chunk_entity_relation_graph, @@ -298,6 +405,7 @@ class LightRAG: asdict(self), ) elif param.mode == "hybrid": + # 混合查询:结合局部和全局信息 response = await hybrid_query( query, self.chunk_entity_relation_graph, @@ -308,6 +416,7 @@ class LightRAG: asdict(self), ) elif param.mode == "naive": + # 简单查询:直接基于文本相似度 response = await naive_query( query, self.chunks_vdb, @@ -317,38 +426,73 @@ class LightRAG: ) else: raise ValueError(f"Unknown mode {param.mode}") + + # 执行查询完成后的清理工作 await self._query_done() return response async def _query_done(self): + """ + 查询操作完成后的回调函数 + 主要用于更新LLM响应缓存的状态 + """ tasks = [] + # 目前只需要处理LLM响应缓存的回调 for storage_inst in [self.llm_response_cache]: if storage_inst is None: continue tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) + # 并发执行所有回调任务 await asyncio.gather(*tasks) def delete_by_entity(self, entity_name: str): + """ + 同步方式删除指定实体 + + 参数: + entity_name: 要删除的实体名称 + """ loop = always_get_an_event_loop() return loop.run_until_complete(self.adelete_by_entity(entity_name)) async def adelete_by_entity(self, entity_name: str): + """ + 异步方式删除指定实体及其相关的所有信息 + + 参数: + entity_name: 要删除的实体名称 + """ + # 标准化实体名称(转为大写并添加引号) entity_name = f'"{entity_name.upper()}"' - try: + # 依次删除实体在各个存储中的数据: + # 1. 从实体向量数据库中删除 await self.entities_vdb.delete_entity(entity_name) + # 2. 从关系向量数据库中删除相关关系 await self.relationships_vdb.delete_relation(entity_name) + # 3. 从知识图谱中删除节点 await self.chunk_entity_relation_graph.delete_node(entity_name) - + + # 记录删除成功的日志 logger.info( f"Entity '{entity_name}' and its relationships have been deleted." ) + # 执行删除完成后的清理工作 await self._delete_by_entity_done() except Exception as e: + # 记录删除过程中的错误 logger.error(f"Error while deleting entity '{entity_name}': {e}") async def _delete_by_entity_done(self): + """ + 实体删除操作完成后的回调函数 + 负责更新所有相关存储实例的状态 + """ tasks = [] + # 遍历需要更新的存储实例: + # - 实体向量数据库 + # - 关系向量数据库 + # - 实体关系图 for storage_inst in [ self.entities_vdb, self.relationships_vdb, @@ -356,5 +500,11 @@ class LightRAG: ]: if storage_inst is None: continue + # 将每个存储实例的回调添加到任务列表 tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) + # 并发执行所有回调任务 await asyncio.gather(*tasks) + + + + diff --git a/lightrag/llm.py b/lightrag/llm.py index f4045e8..eb0a067 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -31,9 +31,10 @@ from typing import List, Dict, Callable, Any from .base import BaseKVStorage from .utils import compute_args_hash, wrap_embedding_func_with_attrs +# 禁用并行化以避免tokenizers的并行化导致的问题 os.environ["TOKENIZERS_PARALLELISM"] = "false" - +# 使用retry装饰器处理重试逻辑,处理OpenAI API的速率限制、连接和超时错误 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -48,35 +49,64 @@ async def openai_complete_if_cache( api_key=None, **kwargs, ) -> str: + """ + 异步函数,通过OpenAI的API获取语言模型的补全结果,支持缓存机制。 + + 参数: + - model: 使用的模型名称 + - prompt: 用户输入的提示 + - system_prompt: 系统提示(可选) + - history_messages: 历史消息(可选) + - base_url: API的基础URL(可选) + - api_key: API密钥(可选) + - **kwargs: 其他参数 + + 返回: + - str: 模型生成的文本 + """ + # 设置环境变量中的API密钥 if api_key: os.environ["OPENAI_API_KEY"] = api_key + # 初始化OpenAI异步客户端 openai_async_client = ( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) + + # 初始化哈希存储和消息列表 hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] + + # 添加系统提示到消息列表 if system_prompt: messages.append({"role": "system", "content": system_prompt}) + + # 将历史消息和当前提示添加到消息列表 messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + + # 检查缓存中是否有结果 if hashing_kv is not None: args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] + # 调用OpenAI API获取补全结果 response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) + # 将结果缓存 if hashing_kv is not None: await hashing_kv.upsert( {args_hash: {"return": response.choices[0].message.content, "model": model}} ) + + # 返回生成的文本 return response.choices[0].message.content - +# 与openai_complete_if_cache类似的函数,但用于Azure OpenAI服务 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -91,45 +121,71 @@ async def azure_openai_complete_if_cache( api_key=None, **kwargs, ): + """ + 异步函数,通过Azure OpenAI的API获取语言模型的补全结果,支持缓存机制。 + + 参数: + - model: 使用的模型名称 + - prompt: 用户输入的提示 + - system_prompt: 系统提示(可选) + - history_messages: 历史消息(可选) + - base_url: API的基础URL(可选) + - api_key: API密钥(可选) + - **kwargs: 其他参数 + + 返回: + - str: 模型生成的文本 + """ + # 设置环境变量中的API密钥和端点 if api_key: os.environ["AZURE_OPENAI_API_KEY"] = api_key if base_url: os.environ["AZURE_OPENAI_ENDPOINT"] = base_url + # 初始化Azure OpenAI异步客户端 openai_async_client = AsyncAzureOpenAI( azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) + # 初始化哈希存储和消息列表 hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] + + # 添加系统提示到消息列表 if system_prompt: messages.append({"role": "system", "content": system_prompt}) + + # 将历史消息和当前提示添加到消息列表 messages.extend(history_messages) if prompt is not None: messages.append({"role": "user", "content": prompt}) + + # 检查缓存中是否有结果 if hashing_kv is not None: args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] + # 调用Azure OpenAI API获取补全结果 response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) + # 将结果缓存 if hashing_kv is not None: await hashing_kv.upsert( {args_hash: {"return": response.choices[0].message.content, "model": model}} ) + + # 返回生成的文本 return response.choices[0].message.content class BedrockError(Exception): - """Generic error for issues related to Amazon Bedrock""" - - + """Amazon Bedrock 相关问题的通用错误""" @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=60), @@ -145,6 +201,25 @@ async def bedrock_complete_if_cache( aws_session_token=None, **kwargs, ) -> str: + """ + 异步使用 Amazon Bedrock 完成文本生成,支持缓存。 + + 如果缓存命中,则直接返回缓存结果。该函数在失败时支持重试。 + + 参数: + - model: 要使用的 Bedrock 模型的模型 ID。 + - prompt: 用户输入的提示。 + - system_prompt: 系统提示,如果有。 + - history_messages: 会话历史消息列表,用于对话上下文。 + - aws_access_key_id: AWS 访问密钥 ID。 + - aws_secret_access_key: AWS 秘密访问密钥。 + - aws_session_token: AWS 会话令牌。 + - **kwargs: 其他参数,例如推理参数。 + + 返回: + - str: 生成的文本结果。 + """ + # 设置 AWS 凭证 os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( "AWS_ACCESS_KEY_ID", aws_access_key_id ) @@ -155,24 +230,24 @@ async def bedrock_complete_if_cache( "AWS_SESSION_TOKEN", aws_session_token ) - # Fix message history format + # 修复消息历史记录格式 messages = [] for history_message in history_messages: message = copy.copy(history_message) message["content"] = [{"text": message["content"]}] messages.append(message) - # Add user prompt + # 添加用户提示 messages.append({"role": "user", "content": [{"text": prompt}]}) - # Initialize Converse API arguments + # 初始化 Converse API 参数 args = {"modelId": model, "messages": messages} - # Define system prompt + # 定义系统提示 if system_prompt: args["system"] = [{"text": system_prompt}] - # Map and set up inference parameters + # 映射并设置推理参数 inference_params_map = { "max_tokens": "maxTokens", "top_p": "topP", @@ -187,6 +262,7 @@ async def bedrock_complete_if_cache( kwargs.pop(param) ) + # 处理缓存逻辑 hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: args_hash = compute_args_hash(model, messages) @@ -194,7 +270,7 @@ async def bedrock_complete_if_cache( if if_cache_return is not None: return if_cache_return["return"] - # Call model via Converse API + # 通过 Converse API 调用模型 session = aioboto3.Session() async with session.client("bedrock-runtime") as bedrock_async_client: try: @@ -202,6 +278,7 @@ async def bedrock_complete_if_cache( except Exception as e: raise BedrockError(e) + # 更新缓存(如果启用) if hashing_kv is not None: await hashing_kv.upsert( { @@ -214,9 +291,20 @@ async def bedrock_complete_if_cache( return response["output"]["message"]["content"][0]["text"] - @lru_cache(maxsize=1) def initialize_hf_model(model_name): + """ + 初始化Hugging Face模型和tokenizer。 + + 使用指定的模型名称初始化模型和tokenizer,并根据需要设置padding token。 + + 参数: + - model_name: 模型的名称。 + + 返回: + - hf_model: 初始化的Hugging Face模型。 + - hf_tokenizer: 初始化的Hugging Face tokenizer。 + """ hf_tokenizer = AutoTokenizer.from_pretrained( model_name, device_map="auto", trust_remote_code=True ) @@ -232,6 +320,21 @@ def initialize_hf_model(model_name): async def hf_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用缓存的Hugging Face模型进行推理。 + + 如果缓存中存在相同的输入,则直接返回结果,否则使用指定的模型进行推理并将结果缓存。 + + 参数: + - model: 模型的名称。 + - prompt: 用户的输入提示。 + - system_prompt: 系统的提示(可选)。 + - history_messages: 历史消息列表(可选)。 + - **kwargs: 其他关键字参数,例如hashing_kv用于缓存存储。 + + 返回: + - response_text: 模型的响应文本。 + """ model_name = model hf_model, hf_tokenizer = initialize_hf_model(model_name) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) @@ -297,32 +400,58 @@ async def hf_model_if_cache( async def ollama_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 异步函数,通过Olama模型生成回答,支持缓存机制以优化性能。 + + 参数: + model: 使用的模型名称。 + prompt: 用户的提问。 + system_prompt: 系统的提示,用于设定对话背景。 + history_messages: 历史对话消息,用于维持对话上下文。 + **kwargs: 其他参数,包括max_tokens, response_format, host, timeout等。 + + 返回: + 生成的模型回答。 + """ + # 移除不需要的参数,以符合Olama客户端的期望 kwargs.pop("max_tokens", None) kwargs.pop("response_format", None) host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) + # 初始化Olama异步客户端 ollama_client = ollama.AsyncClient(host=host, timeout=timeout) + + # 构建消息列表,首先添加系统提示(如果有) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) + # 获取哈希存储实例,用于缓存 hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + + # 将历史消息和当前用户提问添加到消息列表 messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + + # 如果提供了哈希存储,尝试从缓存中获取回答 if hashing_kv is not None: args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] + # 如果缓存中没有回答,调用Olama模型生成回答 response = await ollama_client.chat(model=model, messages=messages, **kwargs) + # 提取生成的回答内容 result = response["message"]["content"] + # 如果使用了哈希存储,将新生成的回答存入缓存 if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": result, "model": model}}) + # 返回生成的回答 return result @@ -335,8 +464,24 @@ def initialize_lmdeploy_pipeline( model_format="hf", quant_policy=0, ): + """ + 初始化lmdeploy管道,用于模型推理,带有缓存机制。 + + 参数: + model: 模型路径。 + tp: 张量并行度。 + chat_template: 聊天模板配置。 + log_level: 日志级别。 + model_format: 模型格式。 + quant_policy: 量化策略。 + + 返回: + 初始化的lmdeploy管道实例。 + """ + # 导入必要的模块和类 from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig + # 创建并配置lmdeploy管道 lmdeploy_pipe = pipeline( model_path=model, backend_config=TurbomindEngineConfig( @@ -347,6 +492,7 @@ def initialize_lmdeploy_pipeline( else None, log_level="WARNING", ) + # 返回配置好的管道实例 return lmdeploy_pipe @@ -361,39 +507,38 @@ async def lmdeploy_model_if_cache( **kwargs, ) -> str: """ - Args: - model (str): The path to the model. - It could be one of the following options: - - i) A local directory path of a turbomind model which is - converted by `lmdeploy convert` command or download - from ii) and iii). - - ii) The model_id of a lmdeploy-quantized model hosted - inside a model repo on huggingface.co, such as + 异步执行语言模型推理,支持缓存。 + + 该函数初始化 lmdeploy 管道进行模型推理,支持多种模型格式和量化策略。它处理输入的提示文本、系统提示和历史消息, + 并尝试从缓存中检索响应。如果未命中缓存,则生成响应并缓存结果以供将来使用。 + + 参数: + model (str): 模型路径。 + 可以是以下选项之一: + - i) 通过 `lmdeploy convert` 命令转换或从 ii) 和 iii) 下载的本地 turbomind 模型目录路径。 + - ii) 在 huggingface.co 上托管的 lmdeploy 量化模型的 model_id,例如 "InternLM/internlm-chat-20b-4bit", - "lmdeploy/llama2-chat-70b-4bit", etc. - - iii) The model_id of a model hosted inside a model repo - on huggingface.co, such as "internlm/internlm-chat-7b", - "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" - and so on. - chat_template (str): needed when model is a pytorch model on - huggingface.co, such as "internlm-chat-7b", - "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, - and when the model name of local path did not match the original model name in HF. - tp (int): tensor parallel - prompt (Union[str, List[str]]): input texts to be completed. - do_preprocess (bool): whether pre-process the messages. Default to - True, which means chat_template will be applied. - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. - Default to be False, which means greedy decoding will be applied. + "lmdeploy/llama2-chat-70b-4bit" 等。 + - iii) 在 huggingface.co 上托管的模型的 model_id,例如 + "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", + "baichuan-inc/Baichuan2-7B-Chat" 等。 + chat_template (str): 当模型是 huggingface.co 上的 PyTorch 模型时需要,例如 "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" 等,以及当本地路径的模型名称与 HF 中的原始模型名称不匹配时。 + tp (int): 张量并行度 + prompt (Union[str, List[str]]): 要完成的输入文本。 + do_preprocess (bool): 是否预处理消息。默认为 True,表示将应用 chat_template。 + skip_special_tokens (bool): 解码时是否移除特殊标记。默认为 True。 + do_sample (bool): 是否使用采样,否则使用贪心解码。默认为 False,表示使用贪心解码。 """ + # 导入 lmdeploy 及相关模块,如果未安装则抛出错误 try: import lmdeploy from lmdeploy import version_info, GenerationConfig except Exception: - raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") + raise ImportError("请在初始化 lmdeploy 后端之前安装 lmdeploy。") + # 提取并处理关键字参数 kwargs.pop("response_format", None) max_new_tokens = kwargs.pop("max_tokens", 512) tp = kwargs.pop("tp", 1) @@ -402,16 +547,18 @@ async def lmdeploy_model_if_cache( do_sample = kwargs.pop("do_sample", False) gen_params = kwargs + # 检查 lmdeploy 版本兼容性,确保支持 do_sample 参数 version = version_info if do_sample is not None and version < (0, 6, 0): raise RuntimeError( - "`do_sample` parameter is not supported by lmdeploy until " - f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}" + "`do_sample` 参数在 lmdeploy v0.6.0 之前不受支持,当前使用的 lmdeploy 版本为 {}" + .format(lmdeploy.__version__) ) else: do_sample = True gen_params.update(do_sample=do_sample) + # 初始化 lmdeploy 管道 lmdeploy_pipe = initialize_lmdeploy_pipeline( model=model, tp=tp, @@ -421,25 +568,31 @@ async def lmdeploy_model_if_cache( log_level="WARNING", ) + # 构建消息列表 messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) + # 获取哈希存储对象 hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + + # 尝试从缓存中获取响应 if hashing_kv is not None: args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] + # 配置生成参数 gen_config = GenerationConfig( skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params, ) + # 生成响应 response = "" async for res in lmdeploy_pipe.generate( messages, @@ -450,14 +603,29 @@ async def lmdeploy_model_if_cache( ): response += res.response + # 缓存生成的响应 if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": response, "model": model}}) + return response + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用GPT-4o模型完成文本生成任务。 + + 参数: + - prompt: 用户输入的提示文本。 + - system_prompt: 系统级别的提示文本,用于指导模型生成。 + - history_messages: 历史对话消息,用于上下文理解。 + - **kwargs: 其他可变关键字参数。 + + 返回: + - 生成的文本结果。 + """ return await openai_complete_if_cache( "gpt-4o", prompt, @@ -466,10 +634,21 @@ async def gpt_4o_complete( **kwargs, ) - async def gpt_4o_mini_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用较小的GPT-4o模型完成文本生成任务。 + + 参数: + - prompt: 用户输入的提示文本。 + - system_prompt: 系统级别的提示文本,用于指导模型生成。 + - history_messages: 历史对话消息,用于上下文理解。 + - **kwargs: 其他可变关键字参数。 + + 返回: + - 生成的文本结果。 + """ return await openai_complete_if_cache( "gpt-4o-mini", prompt, @@ -478,10 +657,21 @@ async def gpt_4o_mini_complete( **kwargs, ) - async def azure_openai_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用Azure上的OpenAI模型完成文本生成任务。 + + 参数: + - prompt: 用户输入的提示文本。 + - system_prompt: 系统级别的提示文本,用于指导模型生成。 + - history_messages: 历史对话消息,用于上下文理解。 + - **kwargs: 其他可变关键字参数。 + + 返回: + - 生成的文本结果。 + """ return await azure_openai_complete_if_cache( "conversation-4o-mini", prompt, @@ -490,10 +680,21 @@ async def azure_openai_complete( **kwargs, ) - async def bedrock_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用Bedrock平台的特定模型完成文本生成任务。 + + 参数: + - prompt: 用户输入的提示文本。 + - system_prompt: 系统级别的提示文本,用于指导模型生成。 + - history_messages: 历史对话消息,用于上下文理解。 + - **kwargs: 其他可变关键字参数。 + + 返回: + - 生成的文本结果。 + """ return await bedrock_complete_if_cache( "anthropic.claude-3-haiku-20240307-v1:0", prompt, @@ -502,10 +703,21 @@ async def bedrock_complete( **kwargs, ) - async def hf_model_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用Hugging Face模型完成文本生成任务。 + + 参数: + - prompt: 用户输入的提示文本。 + - system_prompt: 系统级别的提示文本,用于指导模型生成。 + - history_messages: 历史对话消息,用于上下文理解。 + - **kwargs: 其他可变关键字参数,包括模型名称。 + + 返回: + - 生成的文本结果。 + """ model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await hf_model_if_cache( model_name, @@ -515,10 +727,21 @@ async def hf_model_complete( **kwargs, ) - async def ollama_model_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + """ + 使用Ollama模型完成文本生成任务。 + + 参数: + - prompt: 用户输入的提示文本。 + - system_prompt: 系统级别的提示文本,用于指导模型生成。 + - history_messages: 历史对话消息,用于上下文理解。 + - **kwargs: 其他可变关键字参数,包括模型名称。 + + 返回: + - 生成的文本结果。 + """ model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await ollama_model_if_cache( model_name, @@ -529,7 +752,9 @@ async def ollama_model_complete( ) +# 使用装饰器添加属性,如嵌入维度和最大令牌大小 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +# 使用重试机制处理可能的速率限制、API连接和超时错误 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -541,6 +766,18 @@ async def openai_embedding( base_url: str = None, api_key: str = None, ) -> np.ndarray: + """ + 使用OpenAI模型生成文本嵌入。 + + 参数: + - texts: 需要生成嵌入的文本列表 + - model: 使用的模型名称 + - base_url: API的基础URL + - api_key: API密钥 + + 返回: + - 嵌入的NumPy数组 + """ if api_key: os.environ["OPENAI_API_KEY"] = api_key @@ -552,8 +789,9 @@ async def openai_embedding( ) return np.array([dp.embedding for dp in response.data]) - +# 使用装饰器添加属性,如嵌入维度和最大令牌大小 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +# 使用重试机制处理可能的速率限制、API连接和超时错误 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -565,6 +803,18 @@ async def azure_openai_embedding( base_url: str = None, api_key: str = None, ) -> np.ndarray: + """ + 使用Azure OpenAI模型生成文本嵌入。 + + 参数: + - texts: 需要生成嵌入的文本列表 + - model: 使用的模型名称 + - base_url: API的基础URL + - api_key: API密钥 + + 返回: + - 嵌入的NumPy数组 + """ if api_key: os.environ["AZURE_OPENAI_API_KEY"] = api_key if base_url: @@ -581,7 +831,7 @@ async def azure_openai_embedding( ) return np.array([dp.embedding for dp in response.data]) - +# 使用重试机制处理可能的速率限制、API连接和超时错误 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -594,6 +844,19 @@ async def siliconcloud_embedding( max_token_size: int = 512, api_key: str = None, ) -> np.ndarray: + """ + 使用SiliconCloud模型生成文本嵌入。 + + 参数: + - texts: 需要生成嵌入的文本列表 + - model: 使用的模型名称 + - base_url: API的基础URL + - max_token_size: 最大令牌大小 + - api_key: API密钥 + + 返回: + - 嵌入的NumPy数组 + """ if api_key and not api_key.startswith("Bearer "): api_key = "Bearer " + api_key @@ -633,6 +896,22 @@ async def bedrock_embedding( aws_secret_access_key=None, aws_session_token=None, ) -> np.ndarray: + """ + 生成给定文本的嵌入向量。 + + 使用指定的模型对文本列表进行嵌入处理,支持Amazon Bedrock和Cohere模型。 + + 参数: + - texts: 需要嵌入的文本列表。 + - model: 使用的模型标识符,默认为"amazon.titan-embed-text-v2:0"。 + - aws_access_key_id: AWS访问密钥ID。 + - aws_secret_access_key: AWS秘密访问密钥。 + - aws_session_token: AWS会话令牌。 + + 返回: + - 嵌入向量的NumPy数组。 + """ + # 设置AWS环境变量 os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( "AWS_ACCESS_KEY_ID", aws_access_key_id ) @@ -643,11 +922,14 @@ async def bedrock_embedding( "AWS_SESSION_TOKEN", aws_session_token ) + # 创建aioboto3会话 session = aioboto3.Session() async with session.client("bedrock-runtime") as bedrock_async_client: + # 根据模型提供者进行不同的处理 if (model_provider := model.split(".")[0]) == "amazon": embed_texts = [] for text in texts: + # 根据模型版本构建请求体 if "v2" in model: body = json.dumps( { @@ -661,6 +943,7 @@ async def bedrock_embedding( else: raise ValueError(f"Model {model} is not supported!") + # 调用Bedrock模型 response = await bedrock_async_client.invoke_model( modelId=model, body=body, @@ -693,9 +976,22 @@ async def bedrock_embedding( async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: + """ + 使用Hugging Face模型生成给定文本的嵌入向量。 + + 参数: + - texts: 需要嵌入的文本列表。 + - tokenizer: Hugging Face的标记器实例。 + - embed_model: Hugging Face的嵌入模型实例。 + + 返回: + - 嵌入向量的NumPy数组。 + """ + # 对文本进行标记化处理 input_ids = tokenizer( texts, return_tensors="pt", padding=True, truncation=True ).input_ids + # 使用模型生成嵌入向量 with torch.no_grad(): outputs = embed_model(input_ids) embeddings = outputs.last_hidden_state.mean(dim=1) @@ -703,9 +999,22 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: + """ + 使用Ollama模型生成给定文本的嵌入向量。 + + 参数: + - texts: 需要嵌入的文本列表。 + - embed_model: 使用的嵌入模型标识符。 + - **kwargs: 传递给Ollama客户端的其他参数。 + + 返回: + - 嵌入向量的列表。 + """ embed_text = [] + # 创建Ollama客户端实例 ollama_client = ollama.Client(**kwargs) for text in texts: + # 调用模型生成嵌入向量 data = ollama_client.embeddings(model=embed_model, prompt=text) embed_text.append(data["embedding"]) @@ -798,4 +1107,4 @@ if __name__ == "__main__": result = await gpt_4o_mini_complete("How are you?") print(result) - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file diff --git a/lightrag/operate.py b/lightrag/operate.py index e86388d..bc6e212 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -28,20 +28,48 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS def chunking_by_token_size( - content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" + content: str, + overlap_token_size=128, # 重叠部分的token数量 + max_token_size=1024, # 每个chunk的最大token数量 + tiktoken_model="gpt-4o" # 使用的tokenizer模型名称 ): + """将长文本按照token数量切分成多个重叠的chunk + + Args: + content (str): 需要切分的原始文本内容 + overlap_token_size (int, optional): 相邻chunk之间的重叠token数. Defaults to 128. + max_token_size (int, optional): 每个chunk的最大token数. Defaults to 1024. + tiktoken_model (str, optional): 使用的tokenizer模型. Defaults to "gpt-4o". + + Returns: + list[dict]: 包含切分后的chunk列表,每个chunk包含以下字段: + - tokens: chunk实际包含的token数 + - content: chunk的文本内容 + - chunk_order_index: chunk的序号 + """ + # 使用指定的tokenizer模型将文本编码为tokens tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) + results = [] + # 按照指定的步长(max_token_size - overlap_token_size)遍历tokens + # for index, start in enumerate( range(0, len(tokens), max_token_size - overlap_token_size) ): + # 解码当前窗口范围内的tokens为文本 chunk_content = decode_tokens_by_tiktoken( - tokens[start : start + max_token_size], model_name=tiktoken_model + tokens[start : start + max_token_size], + model_name=tiktoken_model ) + + # 将chunk添加到结果列表中 results.append( { + # 计算实际token数量(最后一个chunk可能不足max_token_size) "tokens": min(max_token_size, len(tokens) - start), + # 去除chunk首尾的空白字符 "content": chunk_content.strip(), + # 记录chunk的序号 "chunk_order_index": index, } ) @@ -49,45 +77,94 @@ def chunking_by_token_size( async def _handle_entity_relation_summary( - entity_or_relation_name: str, - description: str, - global_config: dict, + entity_or_relation_name: str, # 实体名称或关系名称 + description: str, # 需要总结的描述文本 + global_config: dict, # 全局配置参数 ) -> str: + """对实体或关系的描述文本进行总结 + + 当描述文本的token数量超过指定阈值时,使用LLM对其进行总结,以减少token数量 + + Args: + entity_or_relation_name (str): 实体名称或关系名称(用于提示LLM) + description (str): 需要总结的描述文本 + global_config (dict): 包含LLM配置的全局参数字典,必须包含: + - llm_model_func: LLM调用函数 + - llm_model_max_token_size: LLM输入的最大token数 + - tiktoken_model_name: 使用的tokenizer模型名称 + - entity_summary_to_max_tokens: 总结后的最大token数 + + Returns: + str: 如果原文本较短则直接返回,否则返回LLM总结后的文本 + """ + # 从全局配置中获取必要的参数 use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] + # 计算描述文本的token数量 tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) - if len(tokens) < summary_max_tokens: # No need for summary + + # 如果token数量小于阈值,无需总结直接返回 + if len(tokens) < summary_max_tokens: return description + + # 获取总结用的提示模板 prompt_template = PROMPTS["summarize_entity_descriptions"] + + # 截取LLM能处理的最大token数量 use_description = decode_tokens_by_tiktoken( - tokens[:llm_max_tokens], model_name=tiktoken_model_name + tokens[:llm_max_tokens], + model_name=tiktoken_model_name ) + + # 构建提示上下文 context_base = dict( entity_name=entity_or_relation_name, + # 按字段分隔符分割描述文本 description_list=use_description.split(GRAPH_FIELD_SEP), ) + + # 格式化提示模板 use_prompt = prompt_template.format(**context_base) + + # 记录调试日志 logger.debug(f"Trigger summary: {entity_or_relation_name}") + + # 调用LLM生成总结 summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens) + return summary - async def _handle_single_entity_extraction( - record_attributes: list[str], - chunk_key: str, + record_attributes: list[str], # 记录的属性列表 + chunk_key: str, # 当前chunk的唯一标识符 ): + """从记录属性中提取单个实体的信息 + + Args: + record_attributes (list[str]): 包含实体信息的属性列表 + chunk_key (str): 当前chunk的唯一标识符,用于追踪数据来源 + + Returns: + dict: 包含实体信息的字典,如果记录无效则返回None + """ + # 检查记录是否为有效的实体记录 if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None - # add this record as a node in the G + + # 提取并清理实体名称 entity_name = clean_str(record_attributes[1].upper()) - if not entity_name.strip(): + if not entity_name.strip(): # 如果实体名称为空,返回None return None + + # 提取并清理实体类型和描述 entity_type = clean_str(record_attributes[2].upper()) entity_description = clean_str(record_attributes[3]) - entity_source_id = chunk_key + entity_source_id = chunk_key # 记录数据来源 + + # 返回包含实体信息的字典 return dict( entity_name=entity_name, entity_type=entity_type, @@ -97,21 +174,37 @@ async def _handle_single_entity_extraction( async def _handle_single_relationship_extraction( - record_attributes: list[str], - chunk_key: str, + record_attributes: list[str], # 记录的属性列表 + chunk_key: str, # 当前chunk的唯一标识符 ): + """从记录属性中提取单个关系的信息 + + Args: + record_attributes (list[str]): 包含关系信息的属性列表 + chunk_key (str): 当前chunk的唯一标识符,用于追踪数据来源 + + Returns: + dict: 包含关系信息的字典,如果记录无效则返回None + """ + # 检查记录是否为有效的关系记录 if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': return None - # add this record as edge + + # 提取并清理关系的源节点和目标节点 source = clean_str(record_attributes[1].upper()) target = clean_str(record_attributes[2].upper()) + + # 提取并清理关系的描述和关键词 edge_description = clean_str(record_attributes[3]) - edge_keywords = clean_str(record_attributes[4]) - edge_source_id = chunk_key + edge_source_id = chunk_key # 记录数据来源 + + # 提取并计算关系的权重,默认为1.0 weight = ( float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0 ) + + # 返回包含关系信息的字典 return dict( src_id=source, tgt_id=target, @@ -121,25 +214,38 @@ async def _handle_single_relationship_extraction( source_id=edge_source_id, ) - async def _merge_nodes_then_upsert( - entity_name: str, - nodes_data: list[dict], - knowledge_graph_inst: BaseGraphStorage, - global_config: dict, + entity_name: str, # 实体名称 + nodes_data: list[dict], # 待合并的节点数据列表 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + global_config: dict, # 全局配置参数 ): + """合并节点数据并更新到知识图谱中 + + Args: + entity_name (str): 实体名称 + nodes_data (list[dict]): 待合并的节点数据列表 + knowledge_graph_inst (BaseGraphStorage): 知识图谱实例 + global_config (dict): 全局配置参数 + + Returns: + dict: 更新后的节点数据 + """ already_entitiy_types = [] already_source_ids = [] already_description = [] + # 检查知识图谱中是否已存在该节点 already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: + # 如果存在,提取已有的实体类型、来源ID和描述 already_entitiy_types.append(already_node["entity_type"]) already_source_ids.extend( split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP]) ) already_description.append(already_node["description"]) + # 计算合并后的实体类型(选择出现频率最高的类型) entity_type = sorted( Counter( [dp["entity_type"] for dp in nodes_data] + already_entitiy_types @@ -147,42 +253,65 @@ async def _merge_nodes_then_upsert( key=lambda x: x[1], reverse=True, )[0][0] + + # 合并描述和来源ID description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in nodes_data] + already_description)) ) source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in nodes_data] + already_source_ids) ) + + # 使用LLM对合并后的描述进行总结 description = await _handle_entity_relation_summary( entity_name, description, global_config ) + + # 构建节点数据字典 node_data = dict( entity_type=entity_type, description=description, source_id=source_id, ) + + # 更新或插入节点到知识图谱中 await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) + node_data["entity_name"] = entity_name return node_data async def _merge_edges_then_upsert( - src_id: str, - tgt_id: str, - edges_data: list[dict], - knowledge_graph_inst: BaseGraphStorage, - global_config: dict, + src_id: str, # 源节点ID + tgt_id: str, # 目标节点ID + edges_data: list[dict], # 待合并的边数据列表 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + global_config: dict, # 全局配置参数 ): + """合并边数据并更新到知识图谱中 + + Args: + src_id (str): 源节点ID + tgt_id (str): 目标节点ID + edges_data (list[dict]): 待合并的边数据列表 + knowledge_graph_inst (BaseGraphStorage): 知识图谱实例 + global_config (dict): 全局配置参数 + + Returns: + dict: 更新后的边数据 + """ already_weights = [] already_source_ids = [] already_description = [] already_keywords = [] + # 检查知识图谱中是否已存在该边 if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) + # 如果存在,提取已有的权重、来源ID、描述和关键词 already_weights.append(already_edge["weight"]) already_source_ids.extend( split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) @@ -192,7 +321,10 @@ async def _merge_edges_then_upsert( split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) ) + # 计算合并后的权重(累加) weight = sum([dp["weight"] for dp in edges_data] + already_weights) + + # 合并描述、关键词和来源ID description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in edges_data] + already_description)) ) @@ -202,6 +334,8 @@ async def _merge_edges_then_upsert( source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in edges_data] + already_source_ids) ) + + # 确保源节点和目标节点存在于知识图谱中 for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): await knowledge_graph_inst.upsert_node( @@ -212,9 +346,13 @@ async def _merge_edges_then_upsert( "entity_type": '"UNKNOWN"', }, ) + + # 使用LLM对合并后的描述进行总结 description = await _handle_entity_relation_summary( (src_id, tgt_id), description, global_config ) + + # 更新或插入边到知识图谱中 await knowledge_graph_inst.upsert_edge( src_id, tgt_id, @@ -235,19 +373,33 @@ async def _merge_edges_then_upsert( return edge_data - async def extract_entities( - chunks: dict[str, TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, - entity_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - global_config: dict, + chunks: dict[str, TextChunkSchema], # 文本块字典,键为chunk的唯一标识符 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + entity_vdb: BaseVectorStorage, # 实体向量数据库 + relationships_vdb: BaseVectorStorage, # 关系向量数据库 + global_config: dict, # 全局配置参数 ) -> Union[BaseGraphStorage, None]: + """从文本块中提取实体和关系,并更新到知识图谱中 + + Args: + chunks: 待处理的文本块字典 + knowledge_graph_inst: 用于存储实体和关系的知识图谱实例 + entity_vdb: 用于存储实体向量的数据库 + relationships_vdb: 用于存储关系向量的数据库 + global_config: 包含LLM配置等全局参数 + + Returns: + 更新后的知识图谱实例,如果提取失败则返回None + """ + # 从全局配置中获取LLM函数和最大提取次数 use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] + # 将chunks字典转换为有序列表 ordered_chunks = list(chunks.items()) + # 获取提示模板和分隔符配置 entity_extract_prompt = PROMPTS["entity_extraction"] context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], @@ -258,27 +410,43 @@ async def extract_entities( continue_prompt = PROMPTS["entiti_continue_extraction"] if_loop_prompt = PROMPTS["entiti_if_loop_extraction"] + # 初始化处理计数器 already_processed = 0 already_entities = 0 already_relations = 0 async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """处理单个文本块,提取实体和关系 + + Args: + chunk_key_dp: (chunk_key, chunk_data)元组 + + Returns: + (nodes_dict, edges_dict): 提取的实体和关系字典 + """ nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] + + # 构建初始提示并获取LLM响应 hint_prompt = entity_extract_prompt.format(**context_base, input_text=content) final_result = await use_llm_func(hint_prompt) + # 记录对话历史 history = pack_user_ass_to_openai_messages(hint_prompt, final_result) + + # 多轮提取循环 for now_glean_index in range(entity_extract_max_gleaning): + # 继续提取新的实体和关系 glean_result = await use_llm_func(continue_prompt, history_messages=history) - history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) final_result += glean_result + if now_glean_index == entity_extract_max_gleaning - 1: break + # 询问是否需要继续提取 if_loop_result: str = await use_llm_func( if_loop_prompt, history_messages=history ) @@ -286,21 +454,30 @@ async def extract_entities( if if_loop_result != "yes": break + # 分割提取结果为独立记录 records = split_string_by_multi_markers( final_result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) + # 用于收集可能的实体和关系 maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) + + # 处理每条记录 for record in records: + # 提取括号内的属性内容 record = re.search(r"\((.*)\)", record) if record is None: continue record = record.group(1) + + # 分割属性列表 record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) + + # 尝试提取实体 if_entities = await _handle_single_entity_extraction( record_attributes, chunk_key ) @@ -308,6 +485,7 @@ async def extract_entities( maybe_nodes[if_entities["entity_name"]].append(if_entities) continue + # 尝试提取关系 if_relation = await _handle_single_relationship_extraction( record_attributes, chunk_key ) @@ -315,9 +493,13 @@ async def extract_entities( maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( if_relation ) + + # 更新处理计数 already_processed += 1 already_entities += len(maybe_nodes) already_relations += len(maybe_edges) + + # 显示处理进度 now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) ] @@ -328,11 +510,13 @@ async def extract_entities( ) return dict(maybe_nodes), dict(maybe_edges) - # use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings + # 并发处理所有文本块 results = await asyncio.gather( *[_process_single_content(c) for c in ordered_chunks] ) - print() # clear the progress bar + print() # 清除进度条 + + # 合并所有提取结果 maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) for m_nodes, m_edges in results: @@ -340,18 +524,24 @@ async def extract_entities( maybe_nodes[k].extend(v) for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) + + # 合并并更新实体到知识图谱 all_entities_data = await asyncio.gather( *[ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) for k, v in maybe_nodes.items() ] ) + + # 合并并更新关系到知识图谱 all_relationships_data = await asyncio.gather( *[ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config) for k, v in maybe_edges.items() ] ) + + # 检查提取结果 if not len(all_entities_data): logger.warning("Didn't extract any entities, maybe your LLM is not working") return None @@ -361,6 +551,7 @@ async def extract_entities( ) return None + # 更新实体向量数据库 if entity_vdb is not None: data_for_vdb = { compute_mdhash_id(dp["entity_name"], prefix="ent-"): { @@ -371,6 +562,7 @@ async def extract_entities( } await entity_vdb.upsert(data_for_vdb) + # 更新关系向量数据库 if relationships_vdb is not None: data_for_vdb = { compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { @@ -387,28 +579,45 @@ async def extract_entities( return knowledge_graph_inst - async def local_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, + query, # 用户的查询文本 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱存储实例 + entities_vdb: BaseVectorStorage, # 实体的向量数据库 + relationships_vdb: BaseVectorStorage, # 关系的向量数据库 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 原始文本块的存储 + query_param: QueryParam, # 查询参数配置 + global_config: dict, # 全局配置字典 ) -> str: + """本地查询函数,从知识图谱中检索相关信息并生成回答 + + Args: + query: 用户输入的查询文本 + knowledge_graph_inst: 用于存储和查询知识图谱的实例 + entities_vdb: 用于向量检索实体的数据库 + relationships_vdb: 用于向量检索关系的数据库 + text_chunks_db: 存储原始文本块的键值数据库 + query_param: 包含查询相关参数的配置对象 + global_config: 包含全局配置的字典 + + Returns: + str: 查询的响应文本 + """ context = None + # 从全局配置中获取语言模型函数 use_model_func = global_config["llm_model_func"] + # 构建关键词提取的提示 kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) + # 尝试解析语言模型返回的关键词JSON try: keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: + # 第一次解析失败,清理响应文本后重试 try: result = ( result.replace(kw_prompt[:-1], "") @@ -421,10 +630,12 @@ async def local_query( keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) keywords = ", ".join(keywords) - # Handle parsing error except json.JSONDecodeError as e: - print(f"JSON parsing error: {e}") + # 解析彻底失败,返回错误响应 + print(f"JSON解析错误: {e}") return PROMPTS["fail_response"] + + # 如果成功提取到关键词,构建查询上下文 if keywords: context = await _build_local_query_context( keywords, @@ -433,10 +644,14 @@ async def local_query( text_chunks_db, query_param, ) + + # 如果只需要上下文,直接返回 if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] + + # 构建RAG系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -445,6 +660,8 @@ async def local_query( query, system_prompt=sys_prompt, ) + + # 清理语言模型的响应文本 if len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") @@ -460,38 +677,56 @@ async def local_query( async def _build_local_query_context( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + query, # 查询关键词 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + entities_vdb: BaseVectorStorage, # 实体向量数据库 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + query_param: QueryParam, # 查询参数 ): + """构建本地查询的上下文信息 + + 基于查询关键词,从知识图谱和向量数据库中检索相关的实体、关系和原文信息, + 并将它们组织成结构化的上下文 + """ + # 从实体向量数据库中检索最相关的实体 results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return None + + # 并发获取检索到的实体节点数据 node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] ) if not all([n is not None for n in node_datas]): - logger.warning("Some nodes are missing, maybe the storage is damaged") + logger.warning("部分节点数据缺失,存储可能损坏") + + # 获取每个实体节点的度数(连接数) node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] ) + + # 合并节点信息,添加实体名称和排名 node_datas = [ {**n, "entity_name": k["entity_name"], "rank": d} for k, n, d in zip(results, node_datas, node_degrees) if n is not None - ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. + ] + + # 查找与实体最相关的文本单元和关系 use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst ) use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst ) + + # 记录使用的数据量 logger.info( - f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" + f"本地查询使用了 {len(node_datas)} 个实体, {len(use_relations)} 个关系, {len(use_text_units)} 个文本单元" ) + + # 构建实体信息的CSV格式数据 entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(node_datas): entites_section_list.append( @@ -505,6 +740,7 @@ async def _build_local_query_context( ) entities_context = list_of_list_to_csv(entites_section_list) + # 构建关系信息的CSV格式数据 relations_section_list = [ ["id", "source", "target", "description", "keywords", "weight", "rank"] ] @@ -522,10 +758,13 @@ async def _build_local_query_context( ) relations_context = list_of_list_to_csv(relations_section_list) + # 构建文本单元的CSV格式数据 text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) + + # 返回格式化的上下文信息 return f""" -----Entities----- ```csv @@ -541,45 +780,56 @@ async def _build_local_query_context( ``` """ - async def _find_most_related_text_unit_from_entities( - node_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, + node_datas: list[dict], # 实体节点数据列表 + query_param: QueryParam, # 查询参数配置 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): + """查找与给定实体最相关的文本单元 + + 通过分析实体的一跳邻居和关系,找出最相关的原始文本单元 + """ + # 获取每个实体节点关联的文本单元ID text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in node_datas ] + + # 获取每个实体的边(关系)信息 edges = await asyncio.gather( *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] ) + + # 收集所有一跳邻居节点 all_one_hop_nodes = set() for this_edges in edges: if not this_edges: continue all_one_hop_nodes.update([e[1] for e in this_edges]) + # 获取一跳邻居节点的数据 all_one_hop_nodes = list(all_one_hop_nodes) all_one_hop_nodes_data = await asyncio.gather( *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes] ) - # Add null check for node data + # 构建一跳邻居节点的文本单元查找表 all_one_hop_text_units_lookup = { k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP])) for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data) - if v is not None and "source_id" in v # Add source_id check + if v is not None and "source_id" in v # 检查节点数据有效性 } + # 构建所有文本单元的查找表,包含顺序和关系计数信息 all_text_units_lookup = {} for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for c_id in this_text_units: if c_id in all_text_units_lookup: continue + # 计算文本单元与一跳邻居的关系数量 relation_counts = 0 - if this_edges: # Add check for None edges + if this_edges: # 检查边是否存在 for e in this_edges: if ( e[1] in all_one_hop_text_units_lookup @@ -587,15 +837,16 @@ async def _find_most_related_text_unit_from_entities( ): relation_counts += 1 + # 获取文本单元数据 chunk_data = await text_chunks_db.get_by_id(c_id) - if chunk_data is not None and "content" in chunk_data: # Add content check + if chunk_data is not None and "content" in chunk_data: # 检查内容有效性 all_text_units_lookup[c_id] = { "data": chunk_data, "order": index, "relation_counts": relation_counts, } - # Filter out None values and ensure data has content + # 过滤无效数据并确保包含内容 all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() @@ -606,75 +857,99 @@ async def _find_most_related_text_unit_from_entities( logger.warning("No valid text units found") return [] + # 按顺序和关系数量排序 all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) + # 根据token数量限制截断文本单元列表 all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, ) + # 只返回文本单元数据 all_text_units = [t["data"] for t in all_text_units] return all_text_units async def _find_most_related_edges_from_entities( - node_datas: list[dict], - query_param: QueryParam, - knowledge_graph_inst: BaseGraphStorage, + node_datas: list[dict], # 实体节点数据列表 + query_param: QueryParam, # 查询参数配置 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): + """查找与给定实体最相关的边(关系) + + 获取实体的所有关系,并根据度数和权重进行排序 + """ + # 获取所有相关边 all_related_edges = await asyncio.gather( *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] ) + + # 收集所有唯一的边(确保边的方向一致) all_edges = set() for this_edges in all_related_edges: all_edges.update([tuple(sorted(e)) for e in this_edges]) all_edges = list(all_edges) + + # 获取边的详细信息和度数 all_edges_pack = await asyncio.gather( *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges] ) all_edges_degree = await asyncio.gather( *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] ) + + # 合并边的信息 all_edges_data = [ {"src_tgt": k, "rank": d, **v} for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree) if v is not None ] + + # 按度数和权重排序 all_edges_data = sorted( all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True ) + + # 根据token数量限制截断边列表 all_edges_data = truncate_list_by_token_size( all_edges_data, key=lambda x: x["description"], max_token_size=query_param.max_token_for_global_context, ) return all_edges_data - - async def global_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, + query, # 用户查询文本 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + entities_vdb: BaseVectorStorage, # 实体向量数据库 + relationships_vdb: BaseVectorStorage, # 关系向量数据库 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + query_param: QueryParam, # 查询参数配置 + global_config: dict, # 全局配置 ) -> str: + """执行全局查询,基于关系的向量检索 + + 通过分析高层概念关键词,检索相关的关系和实体,构建查询上下文 + """ context = None + # 获取语言模型函数 use_model_func = global_config["llm_model_func"] + # 提取查询关键词 kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) + # 解析关键词JSON结果 try: keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: + # 清理响应文本后重试解析 try: result = ( result.replace(kw_prompt[:-1], "") @@ -689,9 +964,10 @@ async def global_query( keywords = ", ".join(keywords) except json.JSONDecodeError as e: - # Handle parsing error - print(f"JSON parsing error: {e}") + print(f"JSON解析错误: {e}") return PROMPTS["fail_response"] + + # 基于关键词构建查询上下文 if keywords: context = await _build_global_query_context( keywords, @@ -702,11 +978,13 @@ async def global_query( query_param, ) + # 如果只需要上下文,直接返回 if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] + # 构建系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -715,6 +993,8 @@ async def global_query( query, system_prompt=sys_prompt, ) + + # 清理响应文本 if len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") @@ -730,50 +1010,70 @@ async def global_query( async def _build_global_query_context( - keywords, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + keywords, # 查询关键词 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + entities_vdb: BaseVectorStorage, # 实体向量数据库 + relationships_vdb: BaseVectorStorage, # 关系向量数据库 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + query_param: QueryParam, # 查询参数 ): + """构建全局查询的上下文信息 + + 基于关系的向量检索,找出最相关的关系、实体和文本单元 + """ + # 从关系向量数据库检索相关关系 results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results): return None + # 获取关系的详细数据 edge_datas = await asyncio.gather( *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] ) + # 检查数据完整性 if not all([n is not None for n in edge_datas]): - logger.warning("Some edges are missing, maybe the storage is damaged") + logger.warning("部分边数据缺失,存储可能损坏") + + # 获取关系的度数 edge_degree = await asyncio.gather( *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results] ) + + # 合并关系信息 edge_datas = [ {"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v} for k, v, d in zip(results, edge_datas, edge_degree) if v is not None ] + + # 按度数和权重排序 edge_datas = sorted( edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True ) + + # 根据token限制截断关系列表 edge_datas = truncate_list_by_token_size( edge_datas, key=lambda x: x["description"], max_token_size=query_param.max_token_for_global_context, ) + # 获取相关的实体和文本单元 use_entities = await _find_most_related_entities_from_relationships( edge_datas, query_param, knowledge_graph_inst ) use_text_units = await _find_related_text_unit_from_relationships( edge_datas, query_param, text_chunks_db, knowledge_graph_inst ) + + # 记录使用的数据量 logger.info( - f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units" + f"全局查询使用了 {len(use_entities)} 个实体, {len(edge_datas)} 个关系, {len(use_text_units)} 个文本单元" ) + + # 构建关系的CSV格式数据 relations_section_list = [ ["id", "source", "target", "description", "keywords", "weight", "rank"] ] @@ -791,6 +1091,7 @@ async def _build_global_query_context( ) relations_context = list_of_list_to_csv(relations_section_list) + # 构建实体的CSV格式数据 entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(use_entities): entites_section_list.append( @@ -804,11 +1105,13 @@ async def _build_global_query_context( ) entities_context = list_of_list_to_csv(entites_section_list) + # 构建文本单元的CSV格式数据 text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) + # 返回格式化的上下文信息 return f""" -----Entities----- ```csv @@ -824,29 +1127,38 @@ async def _build_global_query_context( ``` """ - async def _find_most_related_entities_from_relationships( - edge_datas: list[dict], - query_param: QueryParam, - knowledge_graph_inst: BaseGraphStorage, + edge_datas: list[dict], # 关系数据列表 + query_param: QueryParam, # 查询参数配置 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): + """从关系中找出最相关的实体节点 + + 通过分析关系的源节点和目标节点,获取所有相关实体的详细信息 + """ + # 收集所有关系中涉及的实体名称 entity_names = set() for e in edge_datas: entity_names.add(e["src_id"]) entity_names.add(e["tgt_id"]) + # 并发获取所有实体节点的详细数据 node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] ) + # 并发获取所有实体节点的度数(连接数) node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names] ) + + # 合并实体信息,添加实体名称和排名 node_datas = [ {**n, "entity_name": k, "rank": d} for k, n, d in zip(entity_names, node_datas, node_degrees) ] + # 根据token数量限制截断实体列表 node_datas = truncate_list_by_token_size( node_datas, key=lambda x: x["description"], @@ -857,18 +1169,25 @@ async def _find_most_related_entities_from_relationships( async def _find_related_text_unit_from_relationships( - edge_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, + edge_datas: list[dict], # 关系数据列表 + query_param: QueryParam, # 查询参数配置 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): + """从关系中找出相关的文本单元 + + 通过分析关系的源文本,获取所有相关的原始文本单元 + """ + # 从每个关系中提取文本单元ID text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in edge_datas ] + # 构建文本单元查找表,记录顺序信息 all_text_units_lookup = {} + # 获取每个文本单元的详细数据 for index, unit_list in enumerate(text_units): for c_id in unit_list: if c_id not in all_text_units_lookup: @@ -877,46 +1196,64 @@ async def _find_related_text_unit_from_relationships( "order": index, } + # 检查数据完整性 if any([v is None for v in all_text_units_lookup.values()]): logger.warning("Text chunks are missing, maybe the storage is damaged") + + # 过滤无效数据并添加ID信息 all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None ] + + # 按出现顺序排序 all_text_units = sorted(all_text_units, key=lambda x: x["order"]) + + # 根据token数量限制截断文本单元列表 all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, ) + + # 只返回文本单元数据 all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units] return all_text_units async def hybrid_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, + query, # 用户查询文本 + knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 + entities_vdb: BaseVectorStorage, # 实体向量数据库 + relationships_vdb: BaseVectorStorage, # 关系向量数据库 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + query_param: QueryParam, # 查询参数配置 + global_config: dict, # 全局配置 ) -> str: + """混合查询函数,结合局部和全局查询的优势 + + 同时提取高层和低层关键词,分别构建上下文后合并 + """ + # 初始化上下文和模型函数 low_level_context = None high_level_context = None use_model_func = global_config["llm_model_func"] + # 构建关键词提取提示 kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) + # 获取语言模型响应 result = await use_model_func(kw_prompt) try: + # 尝试解析关键词JSON keywords_data = json.loads(result) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) hl_keywords = ", ".join(hl_keywords) ll_keywords = ", ".join(ll_keywords) except json.JSONDecodeError: + # 清理响应文本后重试解析 try: result = ( result.replace(kw_prompt[:-1], "") @@ -930,11 +1267,11 @@ async def hybrid_query( ll_keywords = keywords_data.get("low_level_keywords", []) hl_keywords = ", ".join(hl_keywords) ll_keywords = ", ".join(ll_keywords) - # Handle parsing error except json.JSONDecodeError as e: print(f"JSON parsing error: {e}") return PROMPTS["fail_response"] + # 如果有低层关键词,构建局部查询上下文 if ll_keywords: low_level_context = await _build_local_query_context( ll_keywords, @@ -944,6 +1281,7 @@ async def hybrid_query( query_param, ) + # 如果有高层关键词,构建全局查询上下文 if hl_keywords: high_level_context = await _build_global_query_context( hl_keywords, @@ -954,13 +1292,16 @@ async def hybrid_query( query_param, ) + # 合并两种上下文 context = combine_contexts(high_level_context, low_level_context) + # 如果只需要上下文,直接返回 if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] + # 构建系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -969,6 +1310,8 @@ async def hybrid_query( query, system_prompt=sys_prompt, ) + + # 清理响应文本 if len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") @@ -983,27 +1326,33 @@ async def hybrid_query( def combine_contexts(high_level_context, low_level_context): - # Function to extract entities, relationships, and sources from context strings - + """合并高层和低层上下文 + + 从两种上下文中提取并合并实体、关系和来源信息 + """ + # 定义从上下文字符串中提取各部分的内部函数 def extract_sections(context): + # 使用正则表达式匹配实体部分 entities_match = re.search( r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL ) + # 匹配关系部分 relationships_match = re.search( r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL ) + # 匹配来源部分 sources_match = re.search( r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL ) + # 提取匹配内容或返回空字符串 entities = entities_match.group(1) if entities_match else "" relationships = relationships_match.group(1) if relationships_match else "" sources = sources_match.group(1) if sources_match else "" return entities, relationships, sources - # Extract sections from both contexts - + # 从高层上下文提取内容 if high_level_context is None: warnings.warn( "High Level context is None. Return empty High entity/relationship/source" @@ -1012,6 +1361,7 @@ def combine_contexts(high_level_context, low_level_context): else: hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context) + # 从低层上下文提取内容 if low_level_context is None: warnings.warn( "Low Level context is None. Return empty Low entity/relationship/source" @@ -1020,18 +1370,18 @@ def combine_contexts(high_level_context, low_level_context): else: ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) - # Combine and deduplicate the entities + # 合并并去重实体信息 combined_entities = process_combine_contexts(hl_entities, ll_entities) - # Combine and deduplicate the relationships + # 合并并去重关系信息 combined_relationships = process_combine_contexts( hl_relationships, ll_relationships ) - # Combine and deduplicate the sources + # 合并并去重来源信息 combined_sources = process_combine_contexts(hl_sources, ll_sources) - # Format the combined context + # 返回格式化的合并上下文 return f""" -----Entities----- ```csv @@ -1049,28 +1399,48 @@ def combine_contexts(high_level_context, low_level_context): async def naive_query( - query, - chunks_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, + query, # 用户查询文本 + chunks_vdb: BaseVectorStorage, # 文本块向量数据库 + text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 + query_param: QueryParam, # 查询参数配置 + global_config: dict, # 全局配置 ): + """简单查询函数,直接基于文本块的向量检索 + + 不使用知识图谱,仅通过向量相似度检索相关文本块 + """ + # 获取语言模型函数 use_model_func = global_config["llm_model_func"] + + # 从向量数据库检索相关文本块 results = await chunks_vdb.query(query, top_k=query_param.top_k) if not len(results): return PROMPTS["fail_response"] + + # 获取检索结果的ID列表 chunks_ids = [r["id"] for r in results] + + # 从文本块存储获取完整内容 chunks = await text_chunks_db.get_by_ids(chunks_ids) + # 根据token数量限制截断文本块列表 maybe_trun_chunks = truncate_list_by_token_size( chunks, key=lambda x: x["content"], max_token_size=query_param.max_token_for_text_unit, ) + + # 记录截断信息 logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") + + # 将文本块内容拼接成一个字符串 section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) + + # 如果只需要上下文,直接返回拼接的文本 if query_param.only_need_context: return section + + # 构建系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["naive_rag_response"] sys_prompt = sys_prompt_temp.format( content_data=section, response_type=query_param.response_type @@ -1080,6 +1450,7 @@ async def naive_query( system_prompt=sys_prompt, ) + # 清理响应文本 if len(response) > len(sys_prompt): response = ( response[len(sys_prompt) :] @@ -1092,4 +1463,4 @@ async def naive_query( .strip() ) - return response + return response \ No newline at end of file diff --git a/lightrag/storage.py b/lightrag/storage.py index 9a4c3d4..ff828df 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -20,27 +20,63 @@ from .base import ( BaseVectorStorage, ) - @dataclass class JsonKVStorage(BaseKVStorage): + """ + 基于JSON文件的键值存储实现类 + 继承自BaseKVStorage,提供基本的键值存储功能 + 数据以JSON格式保存在文件系统中 + """ + def __post_init__(self): - working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + """ + 初始化方法,在对象创建后自动调用 + - 设置工作目录和文件路径 + - 加载已存在的JSON数据 + """ + working_dir = self.global_config["working_dir"] # 从全局配置获取工作目录 + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") # 构建JSON文件完整路径 + self._data = load_json(self._file_name) or {} # 加载JSON文件,如果不存在则初始化为空字典 + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") # 记录加载数据的数量 async def all_keys(self) -> list[str]: + """ + 获取存储中的所有键 + 返回值: + list[str]: 包含所有键的列表 + """ return list(self._data.keys()) async def index_done_callback(self): + """ + 索引完成后的回调函数 + 将当前内存中的数据写入JSON文件 + """ write_json(self._data, self._file_name) async def get_by_id(self, id): + """ + 通过ID获取单个数据 + 参数: + id: 要查询的数据ID + 返回值: + 查找到的数据,如果不存在则返回None + """ return self._data.get(id, None) async def get_by_ids(self, ids, fields=None): + """ + 批量获取多个ID的数据 + 参数: + ids: ID列表 + fields: 可选,指定要返回的字段列表 + 返回值: + list: 包含查询结果的列表,每个元素对应一个ID的数据 + """ if fields is None: + # 如果未指定字段,返回完整数据 return [self._data.get(id, None) for id in ids] + # 如果指定了字段,只返回指定的字段 return [ ( {k: v for k, v in self._data[id].items() if k in fields} @@ -51,38 +87,80 @@ class JsonKVStorage(BaseKVStorage): ] async def filter_keys(self, data: list[str]) -> set[str]: + """ + 过滤出不存在于存储中的键 + 参数: + data: 要检查的键列表 + 返回值: + set[str]: 不存在的键集合 + """ return set([s for s in data if s not in self._data]) async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - return left_data + """ + 更新或插入数据 + 参数: + data: 要更新/插入的数据字典,格式为 {id: {字段: 值}} + 返回值: + dict: 实际插入的新数据(不包含更新的数据) + """ + left_data = {k: v for k, v in data.items() if k not in self._data} # 筛选出新数据 + self._data.update(left_data) # 更新存储 + return left_data # 返回新插入的数据 async def drop(self): + """ + 清空所有数据 + 将内存中的数据字典重置为空 + """ self._data = {} @dataclass class NanoVectorDBStorage(BaseVectorStorage): + """ + 向量数据库存储实现类 + 基于NanoVectorDB实现向量存储和检索功能 + 支持向量的增删改查操作 + """ + + # 余弦相似度阈值,用于过滤搜索结果 cosine_better_than_threshold: float = 0.2 def __post_init__(self): + """ + 初始化方法,在对象创建后自动调用 + 设置存储文件路径、批处理大小,并初始化向量数据库客户端 + """ + # 构建向量数据库存储文件路径 self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) + # 设置批处理大小 self._max_batch_size = self.global_config["embedding_batch_num"] + # 初始化向量数据库客户端 self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name ) + # 从配置中获取相似度阈值 self.cosine_better_than_threshold = self.global_config.get( "cosine_better_than_threshold", self.cosine_better_than_threshold ) async def upsert(self, data: dict[str, dict]): + """ + 更新或插入向量数据 + 参数: + data: 包含向量数据的字典,格式为 {id: {字段: 值}} + 返回值: + list: 插入结果 + """ logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") return [] + + # 准备数据,提取元数据字段 list_data = [ { "__id__": k, @@ -90,28 +168,49 @@ class NanoVectorDBStorage(BaseVectorStorage): } for k, v in data.items() ] + + # 提取内容并分批处理 contents = [v["content"] for v in data.values()] batches = [ contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] + + # 并行计算向量嵌入 embeddings_list = await asyncio.gather( *[self.embedding_func(batch) for batch in batches] ) embeddings = np.concatenate(embeddings_list) + + # 将向量添加到数据中 for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] + + # 执行更新/插入操作 results = self._client.upsert(datas=list_data) return results async def query(self, query: str, top_k=5): + """ + 查询最相似的向量 + 参数: + query: 查询文本 + top_k: 返回的最相似结果数量 + 返回值: + list: 包含相似度结果的列表 + """ + # 计算查询文本的向量表示 embedding = await self.embedding_func([query]) embedding = embedding[0] + + # 执行向量检索 results = self._client.query( query=embedding, top_k=top_k, better_than_threshold=self.cosine_better_than_threshold, ) + + # 格式化返回结果 results = [ {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results ] @@ -119,12 +218,20 @@ class NanoVectorDBStorage(BaseVectorStorage): @property def client_storage(self): + """获取底层存储对象""" return getattr(self._client, "_NanoVectorDB__storage") async def delete_entity(self, entity_name: str): + """ + 删除指定实体 + 参数: + entity_name: 要删除的实体名称 + """ try: + # 计算实体ID entity_id = [compute_mdhash_id(entity_name, prefix="ent-")] + # 检查并删除实体 if self._client.get(entity_id): self._client.delete(entity_id) logger.info(f"Entity {entity_name} have been deleted.") @@ -134,7 +241,13 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error while deleting entity {entity_name}: {e}") async def delete_relation(self, entity_name: str): + """ + 删除与指定实体相关的所有关系 + 参数: + entity_name: 实体名称 + """ try: + # 查找所有相关关系 relations = [ dp for dp in self.client_storage["data"] @@ -142,6 +255,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ] ids_to_delete = [relation["__id__"] for relation in relations] + # 执行删除操作 if ids_to_delete: self._client.delete(ids_to_delete) logger.info( @@ -155,19 +269,38 @@ class NanoVectorDBStorage(BaseVectorStorage): ) async def index_done_callback(self): + """索引完成后的回调函数,保存数据到存储文件""" self._client.save() @dataclass class NetworkXStorage(BaseGraphStorage): + """ + 基于NetworkX的图存储实现类 + 提供图数据的存储、读取和操作功能 + """ + @staticmethod def load_nx_graph(file_name) -> nx.Graph: + """ + 从文件加载图数据 + 参数: + file_name: GraphML文件路径 + 返回值: + nx.Graph: 加载的图对象,如果文件不存在返回None + """ if os.path.exists(file_name): return nx.read_graphml(file_name) return None @staticmethod def write_nx_graph(graph: nx.Graph, file_name): + """ + 将图数据写入文件 + 参数: + graph: 要保存的图对象 + file_name: 保存路径 + """ logger.info( f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" ) @@ -175,40 +308,51 @@ class NetworkXStorage(BaseGraphStorage): @staticmethod def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: - """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - Return the largest connected component of the graph, with nodes and edges sorted in a stable way. + """ + 获取图的最大连通分量,并确保节点和边的顺序稳定 + 参考: https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py + 参数: + graph: 输入图 + 返回值: + nx.Graph: 处理后的稳定图 """ from graspologic.utils import largest_connected_component graph = graph.copy() graph = cast(nx.Graph, largest_connected_component(graph)) + # 对节点标签进行标准化处理 node_mapping = { node: html.unescape(node.upper().strip()) for node in graph.nodes() - } # type: ignore + } graph = nx.relabel_nodes(graph, node_mapping) return NetworkXStorage._stabilize_graph(graph) @staticmethod def _stabilize_graph(graph: nx.Graph) -> nx.Graph: - """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py - Ensure an undirected graph with the same relationships will always be read the same way. """ + 确保无向图的关系始终以相同的方式读取 + 参数: + graph: 输入图 + 返回值: + nx.Graph: 稳定化后的图 + """ + # 根据图的类型创建新图 fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() + # 对节点进行排序 sorted_nodes = graph.nodes(data=True) sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) + # 添加排序后的节点 fixed_graph.add_nodes_from(sorted_nodes) edges = list(graph.edges(data=True)) + # 对于无向图,确保边的源节点和目标节点有固定顺序 if not graph.is_directed(): - def _sort_source_target(edge): source, target, edge_data = edge if source > target: - temp = source - source = target - target = temp + source, target = target, source return source, target, edge_data edges = [_sort_source_target(edge) for edge in edges] @@ -216,12 +360,18 @@ class NetworkXStorage(BaseGraphStorage): def _get_edge_key(source: Any, target: Any) -> str: return f"{source} -> {target}" + # 对边进行排序 edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) - fixed_graph.add_edges_from(edges) return fixed_graph def __post_init__(self): + """ + 初始化方法 + - 设置图存储文件路径 + - 加载已存在的图数据 + - 初始化节点嵌入算法 + """ self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) @@ -236,46 +386,56 @@ class NetworkXStorage(BaseGraphStorage): } async def index_done_callback(self): + """索引完成后的回调,保存图数据到文件""" NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: + """检查节点是否存在""" return self._graph.has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """检查边是否存在""" return self._graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> Union[dict, None]: + """获取节点数据""" return self._graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: + """获取节点的度""" return self._graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """获取边的度(源节点度 + 目标节点度)""" return self._graph.degree(src_id) + self._graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: + """获取边的数据""" return self._graph.edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str): + """获取节点的所有边""" if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id)) return None async def upsert_node(self, node_id: str, node_data: dict[str, str]): + """更新或插入节点""" self._graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): + """更新或插入边""" self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str): """ - Delete a node from the graph based on the specified node_id. - - :param node_id: The node_id to delete + 删除指定的节点 + 参数: + node_id: 要删除的节点ID """ if self._graph.has_node(node_id): self._graph.remove_node(node_id) @@ -284,12 +444,23 @@ class NetworkXStorage(BaseGraphStorage): logger.warning(f"Node {node_id} not found in the graph for deletion.") async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + """ + 使用指定算法进行节点嵌入 + 参数: + algorithm: 嵌入算法名称 + 返回值: + tuple: (嵌入向量数组, 节点ID列表) + """ if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() - # @TODO: NOT USED async def _node2vec_embed(self): + """ + 使用node2vec算法进行节点嵌入(未使用) + 返回值: + tuple: (嵌入向量数组, 节点ID列表) + """ from graspologic import embed embeddings, nodes = embed.node2vec_embed( @@ -298,4 +469,4 @@ class NetworkXStorage(BaseGraphStorage): ) nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] - return embeddings, nodes_ids + return embeddings, nodes_ids \ No newline at end of file