""" LightRAG - 轻量级检索增强生成系统 该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建 """ import asyncio import os from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial from typing import Type, cast # 导入LLM相关功能 from .llm import ( gpt_4o_mini_complete, # GPT模型完成功能 openai_embedding, # OpenAI文本嵌入功能 ) # 导入核心操作功能 from .operate import ( chunking_by_token_size, # 文本分块 extract_entities, # 实体提取 local_query, # 本地查询 global_query, # 全局查询 hybrid_query, # 混合查询 naive_query, # 简单查询 ) # 导入存储实现 from .storage import ( JsonKVStorage, # JSON键值存储 NanoVectorDBStorage, # 向量数据库存储 NetworkXStorage, # 图数据库存储 ) from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现 # 未来可能的图数据库集成 # from .kg.ArangoDB_impl import ( # GraphStorage as ArangoDBStorage # ) # 导入工具函数 from .utils import ( EmbeddingFunc, # 嵌入函数类型 compute_mdhash_id, # 计算MD5哈希ID limit_async_func_call, # 限制异步函数调用 convert_response_to_json, # 响应转JSON logger, # 日志记录器 set_logger, # 设置日志 ) # 导入基类 from .base import ( BaseGraphStorage, # 图存储基类 BaseKVStorage, # 键值存储基类 BaseVectorStorage, # 向量存储基类 StorageNameSpace, # 存储命名空间 QueryParam, # 查询参数 ) def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ 获取或创建事件循环 如果当前线程没有事件循环,则创建一个新的 返回值: asyncio.AbstractEventLoop: 事件循环实例 """ try: return asyncio.get_event_loop() except RuntimeError: logger.info("Creating a new event loop in main thread.") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop @dataclass class LightRAG: """ 轻量级检索增强生成(LightRAG)系统的主类 实现了文档的存储、检索、知识图谱构建和问答功能 """ # 工作目录配置,用存储所有缓存文件 working_dir: str = field( default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) # 知识图谱存储类型,默认使用NetworkX实现 kg: str = field(default="NetworkXStorage") # 日志级别设置 current_log_level = logger.level log_level: str = field(default=current_log_level) # 文本分块参数配置 chunk_token_size: int = 1200 # 每个文本块的目标token数 chunk_overlap_token_size: int = 100 # 相邻文本块的重叠token数 tiktoken_model_name: str = "gpt-4o-mini" # 用于计算token的模型名称 # 实体提取参数 entity_extract_max_gleaning: int = 1 # 最大实体提取次数 entity_summary_to_max_tokens: int = 500 # 实体摘要的最大token数 # 节点嵌入配置 node_embedding_algorithm: str = "node2vec" # 节点嵌入算法选择 node2vec_params: dict = field( default_factory=lambda: { "dimensions": 1536, # 嵌入向量维度 "num_walks": 10, # 每个节点的随机游走次数 "walk_length": 40, # 每次随机游走的长度 "window_size": 2, # 上下文窗口大小 "iterations": 3, # 训练迭代次数 "random_seed": 3, # 随机种子 } ) # 文本嵌入配置 embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) # 默认使用OpenAI的嵌入模型 embedding_batch_num: int = 32 # 批处理大小 embedding_func_max_async: int = 16 # 最大并发请求数 # 语言模型配置 llm_model_func: callable = gpt_4o_mini_complete # 默认使用的语言模型 llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 模型名称 llm_model_max_token_size: int = 32768 # 模型最大token限制 llm_model_max_async: int = 16 # 最大并发请求数 llm_model_kwargs: dict = field(default_factory=dict) # 模型额外参数 # 存储配置 key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage # 键值存储类 vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage # 向量存储类 vector_db_storage_cls_kwargs: dict = field(default_factory=dict) # 向量存储额外参数 enable_llm_cache: bool = True # 是否启用语言模型缓存 # 扩展配置 addon_params: dict = field(default_factory=dict) # 附加参数 convert_response_to_json_func: callable = convert_response_to_json # JSON转换函数 def __post_init__(self): """ 初始化方法,在对象创建后自动调用 负责设置日志、初始化存储系统和配置各种功能组件 """ # 配置日志系统 log_file = os.path.join(self.working_dir, "lightrag.log") set_logger(log_file) logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") # 记录初始化参数 _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") # 根据配置选择图存储实现类 self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ self.kg ] # 确保工作目录存在 if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) # 初始化文档存储系统 self.full_docs = self.key_string_value_json_storage_cls( namespace="full_docs", global_config=asdict(self) ) # 初始化文本块存储系统 self.text_chunks = self.key_string_value_json_storage_cls( namespace="text_chunks", global_config=asdict(self) ) # 初始化语言模型响应缓存(如果启用) self.llm_response_cache = ( self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self) ) if self.enable_llm_cache else None ) # 初始化实体关系图存储 self.chunk_entity_relation_graph = self.graph_storage_cls( namespace="chunk_entity_relation", global_config=asdict(self) ) # 配置嵌入函数的并发限制 self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) # 初始化向量数据库存储系统 # 用于存储实体的向量表示 self.entities_vdb = self.vector_db_storage_cls( namespace="entities", global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) # 用于存储关系的向量表示 self.relationships_vdb = self.vector_db_storage_cls( namespace="relationships", global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) # 用于存储文本块的向量表示 self.chunks_vdb = self.vector_db_storage_cls( namespace="chunks", global_config=asdict(self), embedding_func=self.embedding_func, ) # 配置语言模型函数的并发限制和缓存 self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, hashing_kv=self.llm_response_cache, **self.llm_model_kwargs, ) ) def _get_storage_class(self) -> Type[BaseGraphStorage]: """ 获取图存储类的实现 根据配置选择合适的图存储后端(Neo4J或NetworkX) 返回值: Type[BaseGraphStorage]: 图存储类 """ return { "Neo4JStorage": Neo4JStorage, "NetworkXStorage": NetworkXStorage, } def insert(self, string_or_strings): """ 同步方式插入文档 将字符串或字符串列表插入到系统中进行处理 参数: string_or_strings: 单个字符串或字符串列表,表示要处理的文档内容 """ loop = always_get_an_event_loop() return loop.run_until_complete(self.ainsert(string_or_strings)) async def ainsert(self, string_or_strings): """ 异步方式插入文档 处理文档内容,包括分块、实体提取和向量化存储 参数: string_or_strings: 单个字符串或字符串列表,表示要处理的文档内容 """ try: # 确保输入是列表形式 if isinstance(string_or_strings, str): string_or_strings = [string_or_strings] # 为每个文档生成唯一ID并创建文档字典 new_docs = { compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()} for c in string_or_strings } # 过滤掉已存在的文档 _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if not len(new_docs): logger.warning("All docs are already in the storage") return logger.info(f"[New Docs] inserting {len(new_docs)} docs") # 处理文档分块 inserting_chunks = {} for doc_key, doc in new_docs.items(): # 对每个文档进行分块,并为每个块生成唯一ID chunks = { compute_mdhash_id(dp["content"], prefix="chunk-"): { **dp, "full_doc_id": doc_key, } for dp in chunking_by_token_size( doc["content"], overlap_token_size=self.chunk_overlap_token_size, max_token_size=self.chunk_token_size, tiktoken_model=self.tiktoken_model_name, ) } inserting_chunks.update(chunks) # 过滤掉已存在的文本块 _add_chunk_keys = await self.text_chunks.filter_keys( list(inserting_chunks.keys()) ) inserting_chunks = { k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } if not len(inserting_chunks): logger.warning("All chunks are already in the storage") return logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") # 将新的文本块插入向量数据库 await self.chunks_vdb.upsert(inserting_chunks) # 提取实体和关系并更新知识图谱 logger.info("[Entity Extraction]...") maybe_new_kg = await extract_entities( inserting_chunks, knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, global_config=asdict(self), ) if maybe_new_kg is None: logger.warning("No new entities and relationships found") return self.chunk_entity_relation_graph = maybe_new_kg # 更新文档和文本块存储 await self.full_docs.upsert(new_docs) await self.text_chunks.upsert(inserting_chunks) finally: # 完成插入后执行清理工作 await self._insert_done() async def _insert_done(self): """ 插入操作完成后的回调函数 负责更新所有存储实例的索引并保存状态 """ tasks = [] # 遍历所有需要执行回调的存储实例 for storage_inst in [ self.full_docs, # 完整文档存储 self.text_chunks, # 文本块存储 self.llm_response_cache, # LLM响应缓存 self.entities_vdb, # 实体向量数据库 self.relationships_vdb, # 关系向量数据库 self.chunks_vdb, # 文本块向量���据库 self.chunk_entity_relation_graph, # 实体关系图 ]: if storage_inst is None: continue # 将每个存储实例的回调任务添加到任务列表 tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) # 并发执行所有回调任务 await asyncio.gather(*tasks) def query(self, query: str, param: QueryParam = QueryParam()): """ 同步方式执行查询 参数: query: 查询文本 param: 查询参数配置 返回值: 查询结果 """ loop = always_get_an_event_loop() return loop.run_until_complete(self.aquery(query, param)) async def aquery(self, query: str, param: QueryParam = QueryParam()): """ 异步方式执行查询 支持多种查询模式:本地查询、全局查询、混合查询和简单查询 参数: query: 查询文本 param: 查询参数配置 返回值: 查询结果 """ # 根据查询模式选择相应的查询方法 if param.mode == "local": # 本地查询:主要基于局部上下文 response = await local_query( query, self.chunk_entity_relation_graph, self.entities_vdb, self.relationships_vdb, self.text_chunks, param, asdict(self), ) elif param.mode == "global": # 全局查询:考虑整个知识图谱 response = await global_query( query, self.chunk_entity_relation_graph, self.entities_vdb, self.relationships_vdb, self.text_chunks, param, asdict(self), ) elif param.mode == "hybrid": # 混合查询:结合局部和全局信息 response = await hybrid_query( query, self.chunk_entity_relation_graph, self.entities_vdb, self.relationships_vdb, self.text_chunks, param, asdict(self), ) elif param.mode == "naive": # 简单查询:直接基于文本相似度 response = await naive_query( query, self.chunks_vdb, self.text_chunks, param, asdict(self), ) else: raise ValueError(f"Unknown mode {param.mode}") # 执行查询完成后的清理工作 await self._query_done() return response async def _query_done(self): """ 查询操作完成后的回调函数 主要用于更新LLM响应缓存的状态 """ tasks = [] # 目前只需要处理LLM响应缓存的回调 for storage_inst in [self.llm_response_cache]: if storage_inst is None: continue tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) # 并发执行所有回调任务 await asyncio.gather(*tasks) def delete_by_entity(self, entity_name: str): """ 同步方式删除指定实体 参数: entity_name: 要删除的实体名称 """ loop = always_get_an_event_loop() return loop.run_until_complete(self.adelete_by_entity(entity_name)) async def adelete_by_entity(self, entity_name: str): """ 异步方式删除指定实体及其相关的所有信息 参数: entity_name: 要删除的实体名称 """ # 标准化实体名称(转为大写并添加引号) entity_name = f'"{entity_name.upper()}"' try: # 依次删除实体在各个存储中的数据: # 1. 从实体向量数据库中删除 await self.entities_vdb.delete_entity(entity_name) # 2. 从关系向量数据库中删除相关关系 await self.relationships_vdb.delete_relation(entity_name) # 3. 从知识图谱中删除节点 await self.chunk_entity_relation_graph.delete_node(entity_name) # 记录删除成功的日志 logger.info( f"Entity '{entity_name}' and its relationships have been deleted." ) # 执行删除完成后的清理工作 await self._delete_by_entity_done() except Exception as e: # 记录删除过程中的错误 logger.error(f"Error while deleting entity '{entity_name}': {e}") async def _delete_by_entity_done(self): """ 实体删除操作完成后的回调函数 负责更新所有相关存储实例的状态 """ tasks = [] # 遍历需要更新的存储实例: # - 实体向量数据库 # - 关系向量数据库 # - 实体关系图 for storage_inst in [ self.entities_vdb, self.relationships_vdb, self.chunk_entity_relation_graph, ]: if storage_inst is None: continue # 将每个存储实例的回调添加到任务列表 tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) # 并发执行所有回调任务 await asyncio.gather(*tasks)