511 lines
18 KiB
Python
511 lines
18 KiB
Python
"""
|
||
LightRAG - 轻量级检索增强生成系统
|
||
该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建
|
||
"""
|
||
|
||
import asyncio
|
||
import os
|
||
from dataclasses import asdict, dataclass, field
|
||
from datetime import datetime
|
||
from functools import partial
|
||
from typing import Type, cast
|
||
|
||
# 导入LLM相关功能
|
||
from .llm import (
|
||
gpt_4o_mini_complete, # GPT模型完成功能
|
||
openai_embedding, # OpenAI文本嵌入功能
|
||
)
|
||
|
||
# 导入核心操作功能
|
||
from .operate import (
|
||
chunking_by_token_size, # 文本分块
|
||
extract_entities, # 实体提取
|
||
local_query, # 本地查询
|
||
global_query, # 全局查询
|
||
hybrid_query, # 混合查询
|
||
naive_query, # 简单查询
|
||
)
|
||
|
||
# 导入存储实现
|
||
from .storage import (
|
||
JsonKVStorage, # JSON键值存储
|
||
NanoVectorDBStorage, # 向量数据库存储
|
||
NetworkXStorage, # 图数据库存储
|
||
)
|
||
|
||
from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现
|
||
# 未来可能的图数据库集成
|
||
# from .kg.ArangoDB_impl import (
|
||
# GraphStorage as ArangoDBStorage
|
||
# )
|
||
|
||
# 导入工具函数
|
||
from .utils import (
|
||
EmbeddingFunc, # 嵌入函数类型
|
||
compute_mdhash_id, # 计算MD5哈希ID
|
||
limit_async_func_call, # 限制异步函数调用
|
||
convert_response_to_json, # 响应转JSON
|
||
logger, # 日志记录器
|
||
set_logger, # 设置日志
|
||
)
|
||
|
||
# 导入基类
|
||
from .base import (
|
||
BaseGraphStorage, # 图存储基类
|
||
BaseKVStorage, # 键值存储基类
|
||
BaseVectorStorage, # 向量存储基类
|
||
StorageNameSpace, # 存储命名空间
|
||
QueryParam, # 查询参数
|
||
)
|
||
|
||
|
||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return a usable event loop for the current thread, creating one if needed.

    ``asyncio.get_event_loop()`` is tried first. If it raises ``RuntimeError``
    or hands back a loop that has already been closed, a new loop is created,
    installed as the current loop via ``asyncio.set_event_loop`` and returned.

    Returns:
        asyncio.AbstractEventLoop: an event loop that is not closed.
    """
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            # A closed loop cannot run anything; reuse the replacement path
            # below instead of returning it to the caller.
            raise RuntimeError("Event loop is closed.")
        return loop
    except RuntimeError:
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
|
||
|
||
|
||
@dataclass
class LightRAG:
    """Main class of the lightweight retrieval-augmented generation system.

    Wires together the key-value, vector and graph storages, the embedding
    function and the LLM function, and exposes synchronous and asynchronous
    entry points for inserting documents, querying, and deleting entities.
    """

    # Working directory; the log file is written here and the whole
    # configuration (including this path) is handed to every storage.
    # The default is timestamped and uses only "-" as separator so that the
    # name is also a valid directory name on Windows (":" is not allowed).
    working_dir: str = field(
        default_factory=lambda: (
            f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        )
    )

    # Name of the knowledge-graph storage backend; must be a key of the
    # mapping returned by `_get_storage_class`.
    kg: str = field(default="NetworkXStorage")

    # Log level. `current_log_level` has no annotation, so it is a plain class
    # attribute (not a dataclass field) captured when the class is defined.
    # NOTE(review): `log_level` is annotated `str` but defaults to
    # `logger.level` — confirm the intended type.
    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # Text chunking
    chunk_token_size: int = 1200  # target number of tokens per chunk
    chunk_overlap_token_size: int = 100  # tokens shared by adjacent chunks
    tiktoken_model_name: str = "gpt-4o-mini"  # model name used for token counting

    # Entity extraction
    entity_extract_max_gleaning: int = 1  # maximum number of extraction passes
    entity_summary_to_max_tokens: int = 500  # maximum tokens of an entity summary

    # Node embedding
    node_embedding_algorithm: str = "node2vec"  # node embedding algorithm
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,  # embedding vector dimension
            "num_walks": 10,  # random walks per node
            "walk_length": 40,  # length of each random walk
            "window_size": 2,  # context window size
            "iterations": 3,  # training iterations
            "random_seed": 3,  # random seed
        }
    )

    # Text embedding
    embedding_func: EmbeddingFunc = field(
        default_factory=lambda: openai_embedding
    )  # defaults to the OpenAI embedding function
    embedding_batch_num: int = 32  # batch size
    embedding_func_max_async: int = 16  # maximum concurrent embedding calls

    # Language model
    llm_model_func: callable = gpt_4o_mini_complete  # default LLM function
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # model name
    llm_model_max_token_size: int = 32768  # maximum token size of the model
    llm_model_max_async: int = 16  # maximum concurrent LLM calls
    llm_model_kwargs: dict = field(default_factory=dict)  # extra LLM keyword arguments

    # Storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage  # KV storage class
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage  # vector storage class
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)  # extra vector storage arguments
    enable_llm_cache: bool = True  # whether LLM responses are cached

    # Extension points
    addon_params: dict = field(default_factory=dict)  # additional parameters
    convert_response_to_json_func: callable = convert_response_to_json  # response-to-JSON converter

    def __post_init__(self):
        """Set up logging, the storages and the rate-limited model functions.

        Runs automatically after the dataclass-generated ``__init__``.

        Raises:
            KeyError: if ``self.kg`` is not a known graph storage name.
        """
        # The log file lives inside the working directory, so the directory
        # has to exist before the log path is handed to `set_logger`.
        # `exist_ok=True` also covers another process creating it concurrently.
        created_working_dir = not os.path.exists(self.working_dir)
        os.makedirs(self.working_dir, exist_ok=True)

        # Configure logging.
        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")
        if created_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")

        # Record the initialisation parameters.
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

        # Resolve the graph storage class from its configured name.
        storage_classes = self._get_storage_class()
        try:
            self.graph_storage_cls: Type[BaseGraphStorage] = storage_classes[self.kg]
        except KeyError:
            raise KeyError(
                f"Unknown graph storage '{self.kg}'; "
                f"expected one of {sorted(storage_classes)}"
            ) from None

        # Each storage receives its own `asdict(self)` snapshot, taken at the
        # point of construction.

        # Full-document storage.
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )

        # Text-chunk storage.
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )

        # LLM response cache (only when enabled).
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )

        # Entity-relation graph storage.
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )

        # Limit the concurrency of the embedding function. This must happen
        # before the vector storages below are built, because they receive
        # the wrapped function.
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        # Vector storage for entities.
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        # Vector storage for relationships.
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        # Vector storage for text chunks.
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        # Bind the response cache and extra keyword arguments to the LLM
        # function, then limit its concurrency.
        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache,
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> dict:
        """Return the mapping from graph storage name to implementation class.

        Returns:
            dict: ``{"Neo4JStorage": Neo4JStorage, "NetworkXStorage": NetworkXStorage}``.
        """
        return {
            "Neo4JStorage": Neo4JStorage,
            "NetworkXStorage": NetworkXStorage,
        }

    async def _run_index_done_callbacks(self, storage_instances):
        """Await ``index_done_callback()`` of every non-``None`` storage concurrently.

        Args:
            storage_instances: iterable of storage objects; ``None`` entries
                (for example a disabled LLM cache) are skipped.
        """
        callbacks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in storage_instances
            if storage_inst is not None
        ]
        await asyncio.gather(*callbacks)

    def insert(self, string_or_strings):
        """Insert documents synchronously.

        Runs `ainsert` to completion on the loop returned by
        `always_get_an_event_loop`.

        Args:
            string_or_strings: a single string or a list of strings holding
                the document contents to process.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Insert documents asynchronously: chunk, embed, extract entities, store.

        `_insert_done` is always awaited at the end, including on early
        return and when an exception propagates.

        NOTE: new chunks are upserted into ``chunks_vdb`` before entity
        extraction. If extraction returns ``None`` the method returns before
        ``full_docs`` and ``text_chunks`` are updated.

        Args:
            string_or_strings: a single string or a list of strings holding
                the document contents to process.
        """
        try:
            # Normalise the input to a list.
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            # Key every stripped document by a hash id of its content.
            new_docs = {}
            for raw_doc in string_or_strings:
                content = raw_doc.strip()
                new_docs[compute_mdhash_id(content, prefix="doc-")] = {
                    "content": content
                }

            # Keep only the documents whose keys `filter_keys` hands back.
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

            if not new_docs:
                logger.warning("All docs are already in the storage")
                return

            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # Split every document into chunks, each keyed by a hash id of
            # its content and tagged with the id of its source document.
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)

            # Keep only the chunks whose keys `filter_keys` hands back.
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }

            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return

            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            # Store the new chunks in the chunk vector database.
            await self.chunks_vdb.upsert(inserting_chunks)

            # Extract entities and relationships and update the knowledge graph.
            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )

            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return

            self.chunk_entity_relation_graph = maybe_new_kg

            # Persist the documents and chunks themselves.
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            # Run the post-insert callbacks whatever happened above.
            await self._insert_done()

    async def _insert_done(self):
        """Run the index-done callback of every storage after an insert."""
        await self._run_index_done_callbacks(
            [
                self.full_docs,  # full-document storage
                self.text_chunks,  # text-chunk storage
                self.llm_response_cache,  # LLM response cache (may be None)
                self.entities_vdb,  # entity vector database
                self.relationships_vdb,  # relationship vector database
                self.chunks_vdb,  # text-chunk vector database
                self.chunk_entity_relation_graph,  # entity-relation graph
            ]
        )

    # NOTE: the default `QueryParam()` below is evaluated once, when the
    # method is defined, so all calls that omit `param` share one instance.
    def query(self, query: str, param: QueryParam = QueryParam()):
        """Run a query synchronously.

        Runs `aquery` to completion on the loop returned by
        `always_get_an_event_loop`.

        Args:
            query: the query text.
            param: query parameters; ``param.mode`` selects the strategy.

        Returns:
            The result of `aquery`.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    # NOTE: as for `query`, the default `QueryParam()` is a single shared instance.
    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        """Run a query asynchronously.

        ``param.mode`` selects one of ``"local"``, ``"global"``, ``"hybrid"``
        (graph-based, all called with the same arguments) or ``"naive"``
        (chunk vector database only).

        Args:
            query: the query text.
            param: query parameters.

        Returns:
            The response produced by the selected query function.

        Raises:
            ValueError: if ``param.mode`` is not one of the four modes above.
        """
        # The three graph-based modes take identical arguments.
        graph_query_funcs = {
            "local": local_query,
            "global": global_query,
            "hybrid": hybrid_query,
        }

        if param.mode in graph_query_funcs:
            response = await graph_query_funcs[param.mode](
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")

        # Run the post-query callbacks.
        await self._query_done()
        return response

    async def _query_done(self):
        """Run the index-done callback of the LLM response cache after a query."""
        await self._run_index_done_callbacks([self.llm_response_cache])

    def delete_by_entity(self, entity_name: str):
        """Delete an entity synchronously.

        Runs `adelete_by_entity` to completion on the loop returned by
        `always_get_an_event_loop`.

        Args:
            entity_name: name of the entity to delete.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        """Delete an entity and its relationships from all storages.

        Best effort: any exception raised while deleting is logged and not
        re-raised.

        Args:
            entity_name: name of the entity to delete; it is upper-cased and
                wrapped in double quotes before being passed to the storages.
        """
        # Normalise the entity name (upper case, surrounded by double quotes).
        entity_name = f'"{entity_name.upper()}"'
        try:
            # 1. Remove the entity from the entity vector database.
            await self.entities_vdb.delete_entity(entity_name)
            # 2. Remove its relationships from the relationship vector database.
            await self.relationships_vdb.delete_relation(entity_name)
            # 3. Remove its node from the knowledge graph.
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            # Run the post-delete callbacks.
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        """Run the index-done callback of the storages touched by a deletion."""
        await self._run_index_done_callbacks(
            [
                self.entities_vdb,  # entity vector database
                self.relationships_vdb,  # relationship vector database
                self.chunk_entity_relation_graph,  # entity-relation graph
            ]
        )
|
||
|
||
|
||
|
||
|