lightrag-comments/lightrag/lightrag.py

529 lines
19 KiB
Python
Raw Permalink Normal View History

"""
LightRAG - 轻量级检索增强生成系统
该模块实现了一个基于图的文档检索和问答系统支持文档的存储检索和知识图谱构建
"""
2024-11-16 11:59:20 +08:00
# 导入异步IO模块用于处理异步编程
import asyncio
2024-11-16 11:59:20 +08:00
# 导入操作系统接口模块,用于处理文件路径和环境变量
import os
2024-11-16 11:59:20 +08:00
# 从dataclasses模块导入数据类相关工具
from dataclasses import (
asdict, # 将数据类实例转换为字典的函数
dataclass, # 数据类装饰器
field, # 用于定义数据类字段的函数
)
# 导入日期时间处理模块
from datetime import datetime
2024-11-16 11:59:20 +08:00
# 从functools导入partial函数用于创建偏函数
from functools import partial
2024-11-16 11:59:20 +08:00
# 从typing模块导入类型提示工具
from typing import (
Type, # 用于类型注解中表示类型的类型
cast, # 用于类型转换的函数
)
# 导入LLM相关功能
from .llm import (
gpt_4o_mini_complete, # GPT模型完成功能
openai_embedding, # OpenAI文本嵌入功能
)
# 导入核心操作功能
from .operate import (
chunking_by_token_size, # 文本分块
extract_entities, # 实体提取
local_query, # 本地查询
global_query, # 全局查询
hybrid_query, # 混合查询
naive_query, # 简单查询
)
# 导入存储实现
from .storage import (
JsonKVStorage, # JSON键值存储
NanoVectorDBStorage, # 向量数据库存储
NetworkXStorage, # 图数据库存储
)
from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现
# 未来可能的图数据库集成
# from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage
# )
# 导入工具函数
from .utils import (
EmbeddingFunc, # 嵌入函数类型
compute_mdhash_id, # 计算MD5哈希ID
limit_async_func_call, # 限制异步函数调用
convert_response_to_json, # 响应转JSON
logger, # 日志记录器
set_logger, # 设置日志
)
# 导入基类
from .base import (
BaseGraphStorage, # 图存储基类
BaseKVStorage, # 键值存储基类
BaseVectorStorage, # 向量存储基类
StorageNameSpace, # 存储命名空间
QueryParam, # 查询参数
)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
获取或创建事件循环
如果当前线程没有事件循环则创建一个新的
返回值:
asyncio.AbstractEventLoop: 事件循环实例
"""
try:
return asyncio.get_event_loop()
except RuntimeError:
logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
@dataclass
class LightRAG:
"""
轻量级检索增强生成(LightRAG)系统的主类
实现了文档的存储检索知识图谱构建和问答功能
"""
# 工作目录配置,用存储所有缓存文件
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
# 知识图谱存储类型默认使用NetworkX实现
kg: str = field(default="NetworkXStorage")
# 日志级别设置
current_log_level = logger.level
log_level: str = field(default=current_log_level)
# 文本分块参数配置
chunk_token_size: int = 1200 # 每个文本块的目标token数
chunk_overlap_token_size: int = 100 # 相邻文本块的重叠token数
tiktoken_model_name: str = "gpt-4o-mini" # 用于计算token的模型名称
# 实体提取参数
entity_extract_max_gleaning: int = 1 # 最大实体提取次数
entity_summary_to_max_tokens: int = 500 # 实体摘要的最大token数
# 节点嵌入配置
node_embedding_algorithm: str = "node2vec" # 节点嵌入算法选择
node2vec_params: dict = field(
default_factory=lambda: {
"dimensions": 1536, # 嵌入向量维度
"num_walks": 10, # 每个节点的随机游走次数
"walk_length": 40, # 每次随机游走的长度
"window_size": 2, # 上下文窗口大小
"iterations": 3, # 训练迭代次数
"random_seed": 3, # 随机种子
}
)
# 文本嵌入配置
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) # 默认使用OpenAI的嵌入模型
embedding_batch_num: int = 32 # 批处理大小
embedding_func_max_async: int = 16 # 最大并发请求数
# 语言模型配置
llm_model_func: callable = gpt_4o_mini_complete # 默认使用的语言模型
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 模型名称
llm_model_max_token_size: int = 32768 # 模型最大token限制
llm_model_max_async: int = 16 # 最大并发请求数
llm_model_kwargs: dict = field(default_factory=dict) # 模型额外参数
# 存储配置
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage # 键值存储类
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage # 向量存储类
vector_db_storage_cls_kwargs: dict = field(default_factory=dict) # 向量存储额外参数
enable_llm_cache: bool = True # 是否启用语言模型缓存
# 扩展配置
addon_params: dict = field(default_factory=dict) # 附加参数
convert_response_to_json_func: callable = convert_response_to_json # JSON转换函数
def __post_init__(self):
"""
初始化方法在对象创建后自动调用
负责设置日志初始化存储系统和配置各种功能组件
"""
# 配置日志系统
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
# 记录初始化参数
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# 根据配置选择图存储实现类
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.kg
]
# 确保工作目录存在
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# 初始化文档存储系统
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs", global_config=asdict(self)
)
# 初始化文本块存储系统
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks", global_config=asdict(self)
)
# 初始化语言模型响应缓存(如果启用)
self.llm_response_cache = (
self.key_string_value_json_storage_cls(
namespace="llm_response_cache", global_config=asdict(self)
)
if self.enable_llm_cache
else None
)
# 初始化实体关系图存储
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation", global_config=asdict(self)
)
# 配置嵌入函数的并发限制
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# 初始化向量数据库存储系统
# 用于存储实体的向量表示
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
# 用于存储关系的向量表示
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
# 用于存储文本块的向量表示
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
# 配置语言模型函数的并发限制和缓存
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
hashing_kv=self.llm_response_cache,
**self.llm_model_kwargs,
)
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
"""
获取图存储类的实现
根据配置选择合适的图存储后端Neo4J或NetworkX
返回值:
Type[BaseGraphStorage]: 图存储类
"""
return {
"Neo4JStorage": Neo4JStorage,
"NetworkXStorage": NetworkXStorage,
}
def insert(self, string_or_strings):
"""
同步方式插入文档
将字符串或字符串列表插入到系统中进行处理
参数:
string_or_strings: 单个字符串或字符串列表表示要处理的文档内容
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))
async def ainsert(self, string_or_strings):
"""
异步方式插入文档
处理文档内容包括分块实体提取和向量化存储
参数:
string_or_strings: 单个字符串或字符串列表表示要处理的文档内容
"""
try:
# 确保输入是列表形式
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 为每个文档生成唯一ID并创建文档字典
new_docs = {
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
}
# 过滤掉已存在的文档
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("All docs are already in the storage")
return
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
# 处理文档分块
inserting_chunks = {}
for doc_key, doc in new_docs.items():
# 对每个文档进行分块并为每个块生成唯一ID
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_key,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
inserting_chunks.update(chunks)
# 过滤掉已存在的文本块
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
# 将新的文本块插入向量数据库
await self.chunks_vdb.upsert(inserting_chunks)
# 提取实体和关系并更新知识图谱
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities and relationships found")
return
self.chunk_entity_relation_graph = maybe_new_kg
# 更新文档和文本块存储
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
# 完成插入后执行清理工作
await self._insert_done()
async def _insert_done(self):
"""
插入操作完成后的回调函数
负责更新所有存储实例的索引并保存状态
"""
tasks = []
# 遍历所有需要执行回调的存储实例
for storage_inst in [
self.full_docs, # 完整文档存储
self.text_chunks, # 文本块存储
self.llm_response_cache, # LLM响应缓存
self.entities_vdb, # 实体向量数据库
self.relationships_vdb, # 关系向量数据库
self.chunks_vdb, # 文本块向量<E59091><E9878F><EFBFBD>据库
self.chunk_entity_relation_graph, # 实体关系图
]:
if storage_inst is None:
continue
# 将每个存储实例的回调任务添加到任务列表
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
# 并发执行所有回调任务
await asyncio.gather(*tasks)
def query(self, query: str, param: QueryParam = QueryParam()):
"""
同步方式执行查询
参数:
query: 查询文本
param: 查询参数配置
返回值:
查询结果
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))
async def aquery(self, query: str, param: QueryParam = QueryParam()):
"""
异步方式执行查询
支持多种查询模式本地查询全局查询混合查询和简单查询
参数:
query: 查询文本
param: 查询参数配置
返回值:
查询结果
"""
# 根据查询模式选择相应的查询方法
if param.mode == "local":
# 本地查询:主要基于局部上下文
response = await local_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "global":
# 全局查询:考虑整个知识图谱
response = await global_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "hybrid":
# 混合查询:结合局部和全局信息
response = await hybrid_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "naive":
# 简单查询:直接基于文本相似度
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
# 执行查询完成后的清理工作
await self._query_done()
return response
async def _query_done(self):
"""
查询操作完成后的回调函数
主要用于更新LLM响应缓存的状态
"""
tasks = []
# 目前只需要处理LLM响应缓存的回调
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
# 并发执行所有回调任务
await asyncio.gather(*tasks)
def delete_by_entity(self, entity_name: str):
"""
同步方式删除指定实体
参数:
entity_name: 要删除的实体名称
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name))
async def adelete_by_entity(self, entity_name: str):
"""
异步方式删除指定实体及其相关的所有信息
参数:
entity_name: 要删除的实体名称
"""
# 标准化实体名称(转为大写并添加引号)
entity_name = f'"{entity_name.upper()}"'
try:
# 依次删除实体在各个存储中的数据:
# 1. 从实体向量数据库中删除
await self.entities_vdb.delete_entity(entity_name)
# 2. 从关系向量数据库中删除相关关系
await self.relationships_vdb.delete_relation(entity_name)
# 3. 从知识图谱中删除节点
await self.chunk_entity_relation_graph.delete_node(entity_name)
# 记录删除成功的日志
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
# 执行删除完成后的清理工作
await self._delete_by_entity_done()
except Exception as e:
# 记录删除过程中的错误
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
"""
实体删除操作完成后的回调函数
负责更新所有相关存储实例的状态
"""
tasks = []
# 遍历需要更新的存储实例:
# - 实体向量数据库
# - 关系向量数据库
# - 实体关系图
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
# 将每个存储实例的回调添加到任务列表
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
# 并发执行所有回调任务
await asyncio.gather(*tasks)