2024-11-16 11:29:02 +08:00
|
|
|
|
"""
|
|
|
|
|
LightRAG - 轻量级检索增强生成系统
|
|
|
|
|
该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建
|
|
|
|
|
"""
|
|
|
|
|
|
2024-11-16 11:59:20 +08:00
|
|
|
|
# 导入异步IO模块,用于处理异步编程
|
2024-11-16 11:26:57 +08:00
|
|
|
|
import asyncio
|
2024-11-16 11:59:20 +08:00
|
|
|
|
|
|
|
|
|
# 导入操作系统接口模块,用于处理文件路径和环境变量
|
2024-11-16 11:26:57 +08:00
|
|
|
|
import os
|
2024-11-16 11:59:20 +08:00
|
|
|
|
|
|
|
|
|
# 从dataclasses模块导入数据类相关工具
|
|
|
|
|
from dataclasses import (
|
|
|
|
|
asdict, # 将数据类实例转换为字典的函数
|
|
|
|
|
dataclass, # 数据类装饰器
|
|
|
|
|
field, # 用于定义数据类字段的函数
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 导入日期时间处理模块
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from datetime import datetime
|
2024-11-16 11:59:20 +08:00
|
|
|
|
|
|
|
|
|
# 从functools导入partial函数,用于创建偏函数
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from functools import partial
|
2024-11-16 11:59:20 +08:00
|
|
|
|
|
|
|
|
|
# 从typing模块导入类型提示工具
|
|
|
|
|
from typing import (
|
|
|
|
|
Type, # 用于类型注解中表示类型的类型
|
|
|
|
|
cast, # 用于类型转换的函数
|
|
|
|
|
)
|
2024-11-16 11:26:57 +08:00
|
|
|
|
|
2024-11-16 11:29:02 +08:00
|
|
|
|
# 导入LLM相关功能
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from .llm import (
|
2024-11-16 11:29:02 +08:00
|
|
|
|
gpt_4o_mini_complete, # GPT模型完成功能
|
|
|
|
|
openai_embedding, # OpenAI文本嵌入功能
|
2024-11-16 11:26:57 +08:00
|
|
|
|
)
|
2024-11-16 11:29:02 +08:00
|
|
|
|
|
|
|
|
|
# 导入核心操作功能
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from .operate import (
|
2024-11-16 11:29:02 +08:00
|
|
|
|
chunking_by_token_size, # 文本分块
|
|
|
|
|
extract_entities, # 实体提取
|
|
|
|
|
local_query, # 本地查询
|
|
|
|
|
global_query, # 全局查询
|
|
|
|
|
hybrid_query, # 混合查询
|
|
|
|
|
naive_query, # 简单查询
|
2024-11-16 11:26:57 +08:00
|
|
|
|
)
|
|
|
|
|
|
2024-11-16 11:29:02 +08:00
|
|
|
|
# 导入存储实现
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from .storage import (
|
2024-11-16 11:29:02 +08:00
|
|
|
|
JsonKVStorage, # JSON键值存储
|
|
|
|
|
NanoVectorDBStorage, # 向量数据库存储
|
|
|
|
|
NetworkXStorage, # 图数据库存储
|
2024-11-16 11:26:57 +08:00
|
|
|
|
)
|
|
|
|
|
|
2024-11-16 11:29:02 +08:00
|
|
|
|
from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现
|
|
|
|
|
# 未来可能的图数据库集成
|
2024-11-16 11:26:57 +08:00
|
|
|
|
# from .kg.ArangoDB_impl import (
|
|
|
|
|
# GraphStorage as ArangoDBStorage
|
|
|
|
|
# )
|
|
|
|
|
|
2024-11-16 11:29:02 +08:00
|
|
|
|
# 导入工具函数
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from .utils import (
|
2024-11-16 11:29:02 +08:00
|
|
|
|
EmbeddingFunc, # 嵌入函数类型
|
|
|
|
|
compute_mdhash_id, # 计算MD5哈希ID
|
|
|
|
|
limit_async_func_call, # 限制异步函数调用
|
|
|
|
|
convert_response_to_json, # 响应转JSON
|
|
|
|
|
logger, # 日志记录器
|
|
|
|
|
set_logger, # 设置日志
|
2024-11-16 11:26:57 +08:00
|
|
|
|
)
|
2024-11-16 11:29:02 +08:00
|
|
|
|
|
|
|
|
|
# 导入基类
|
2024-11-16 11:26:57 +08:00
|
|
|
|
from .base import (
|
2024-11-16 11:29:02 +08:00
|
|
|
|
BaseGraphStorage, # 图存储基类
|
|
|
|
|
BaseKVStorage, # 键值存储基类
|
|
|
|
|
BaseVectorStorage, # 向量存储基类
|
|
|
|
|
StorageNameSpace, # 存储命名空间
|
|
|
|
|
QueryParam, # 查询参数
|
2024-11-16 11:26:57 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return an open event loop for the current thread, creating one if needed.

    Reuses the loop returned by ``asyncio.get_event_loop()`` as long as it is
    still open. If that call raises ``RuntimeError`` (for example, no loop is
    set for the current thread), or the loop it returns has already been
    closed, a new loop is created and installed as the current thread's loop.

    Returns:
        asyncio.AbstractEventLoop: an event loop that is not closed, suitable
        for ``run_until_complete``.
    """
    try:
        current_loop = asyncio.get_event_loop()
        # get_event_loop() happily returns a loop that was closed earlier;
        # run_until_complete() on it would fail with "Event loop is closed",
        # so treat that case like "no loop" and fall through to creating one.
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop
    except RuntimeError:
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class LightRAG:
    """Main class of the LightRAG (lightweight retrieval-augmented generation) system.

    Ingests documents (chunking, vector indexing, entity/relationship
    extraction into a knowledge graph), answers queries in several retrieval
    modes, and deletes entities. All storage backends are created in
    ``__post_init__``; each receives ``asdict(self)`` (which includes
    ``working_dir``) as its global configuration.

    The synchronous wrappers (``insert``, ``query``, ``delete_by_entity``) call
    ``run_until_complete`` and therefore cannot be used from inside an event
    loop that is already running; use the ``a*`` coroutines there instead.
    """

    # Directory for all cache/storage files and the log file. The default is a
    # fresh timestamped directory; the time part uses '-' instead of ':' because
    # ':' is not allowed in file names on Windows.
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    )

    # Name of the knowledge-graph storage backend; must be a key of the dict
    # returned by _get_storage_class().
    kg: str = field(default="NetworkXStorage")

    # Level applied to the shared logger in __post_init__. The default is the
    # logger's level at import time (an int); the value is passed straight to
    # logger.setLevel, so a level name string presumably works too.
    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # Text chunking parameters.
    chunk_token_size: int = 1200  # target number of tokens per chunk
    chunk_overlap_token_size: int = 100  # tokens shared by adjacent chunks
    tiktoken_model_name: str = "gpt-4o-mini"  # model whose tokenizer counts tokens

    # Entity extraction parameters.
    entity_extract_max_gleaning: int = 1  # maximum number of extraction (gleaning) passes
    entity_summary_to_max_tokens: int = 500  # maximum tokens of an entity summary

    # Node embedding configuration.
    node_embedding_algorithm: str = "node2vec"  # node embedding algorithm
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,  # embedding vector dimension
            "num_walks": 10,  # random walks per node
            "walk_length": 40,  # length of each random walk
            "window_size": 2,  # context window size
            "iterations": 3,  # training iterations
            "random_seed": 3,  # random seed
        }
    )

    # Text embedding configuration.
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32  # embedding batch size
    embedding_func_max_async: int = 16  # max concurrent embedding calls

    # Language model configuration.
    llm_model_func: callable = gpt_4o_mini_complete  # LLM completion function
    # NOTE(review): this name does not match the default llm_model_func;
    # presumably only read by backends that take the model name from the
    # global config -- confirm.
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
    llm_model_max_token_size: int = 32768  # model token limit
    llm_model_max_async: int = 16  # max concurrent LLM calls
    llm_model_kwargs: dict = field(default_factory=dict)  # extra kwargs bound into llm_model_func

    # Storage configuration.
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage  # key-value storage class
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage  # vector storage class
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)  # extra vector storage options
    enable_llm_cache: bool = True  # create llm_response_cache and pass it to the LLM function

    # Extension configuration.
    addon_params: dict = field(default_factory=dict)  # additional parameters
    convert_response_to_json_func: callable = convert_response_to_json  # response-to-JSON converter

    def __post_init__(self):
        """Set up logging, storage backends and rate-limited model functions.

        Called automatically by the dataclass-generated ``__init__``.

        Raises:
            KeyError: if ``kg`` does not name a backend returned by
                ``_get_storage_class``.
        """
        # Resolve the graph backend first so an invalid ``kg`` fails before
        # any directory or log file is created.
        graph_storage_classes = self._get_storage_class()
        if self.kg not in graph_storage_classes:
            raise KeyError(
                f"Unknown graph storage {self.kg!r}; "
                f"expected one of {sorted(graph_storage_classes)}"
            )
        self.graph_storage_cls: Type[BaseGraphStorage] = graph_storage_classes[self.kg]

        # The log file lives inside working_dir, so the directory has to exist
        # before set_logger is called (it presumably opens a file handler on
        # that path immediately, which fails for a missing directory).
        created_working_dir = not os.path.exists(self.working_dir)
        if created_working_dir:
            os.makedirs(self.working_dir, exist_ok=True)

        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")
        if created_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")

        # Dump the effective configuration at debug level.
        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # Key-value stores for full documents and text chunks.
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )

        # Optional cache of LLM responses (None when caching is disabled).
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )

        # Entity/relationship knowledge graph.
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )

        # Cap concurrent embedding calls.
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        # Vector stores for entities, relationships and raw text chunks.
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        # Bind the response cache and extra kwargs into the LLM function and
        # cap its concurrency.
        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache,
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> dict:
        """Return the mapping of graph storage backend names to classes.

        Returns:
            dict: maps each accepted value of ``kg`` (``"Neo4JStorage"``,
            ``"NetworkXStorage"``) to its ``BaseGraphStorage`` implementation.
        """
        return {
            "Neo4JStorage": Neo4JStorage,
            "NetworkXStorage": NetworkXStorage,
        }

    def insert(self, string_or_strings):
        """Synchronously insert one document or a list of documents.

        Args:
            string_or_strings: a single document string or a list of them.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Asynchronously insert one document or a list of documents.

        Documents are stripped and keyed by a content hash, blank ones are
        ignored, and documents/chunks that are already stored are skipped. New
        chunks are upserted into ``chunks_vdb`` and passed to entity
        extraction, which updates the knowledge graph. Storage
        ``index_done_callback``s always run afterwards (``finally``).

        Args:
            string_or_strings: a single document string or a list of them.
        """
        try:
            # Accept a single document as well as a list of documents.
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            # Key each document by the hash of its stripped content. Blank
            # documents carry no text to index, so drop them up front.
            new_docs = {}
            for raw_doc in string_or_strings:
                content = raw_doc.strip()
                if content:
                    new_docs[compute_mdhash_id(content, prefix="doc-")] = {"content": content}
            if not new_docs:
                logger.warning("No non-empty docs to insert")
                return

            # Skip documents that are already stored.
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # Split each new document into token-bounded chunks keyed by the
            # hash of the chunk content, remembering the source document.
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                inserting_chunks.update(
                    {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_key,
                        }
                        for dp in chunking_by_token_size(
                            doc["content"],
                            overlap_token_size=self.chunk_overlap_token_size,
                            max_token_size=self.chunk_token_size,
                            tiktoken_model=self.tiktoken_model_name,
                        )
                    }
                )

            # Skip chunks that are already stored.
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            # Index the new chunks in the vector store.
            await self.chunks_vdb.upsert(inserting_chunks)

            # Extract entities/relationships and update the knowledge graph.
            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            # Record documents and chunks only after extraction succeeded; since
            # stored keys are filtered out above, anything not recorded here is
            # treated as new again on the next insert of the same content.
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            await self._insert_done()

    async def _insert_done(self):
        """Run ``index_done_callback`` concurrently on every storage touched by inserts."""
        tasks = []
        for storage_inst in [
            self.full_docs,  # full documents
            self.text_chunks,  # text chunks
            self.llm_response_cache,  # LLM response cache (may be None)
            self.entities_vdb,  # entity vector store
            self.relationships_vdb,  # relationship vector store
            self.chunks_vdb,  # text chunk vector store
            self.chunk_entity_relation_graph,  # entity/relationship graph
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def query(self, query: str, param: "QueryParam | None" = None):
        """Synchronously answer ``query``.

        Args:
            query: the question text.
            param: query options; a fresh ``QueryParam()`` is used when omitted.

        Returns:
            Whatever the selected query mode returns.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: "QueryParam | None" = None):
        """Asynchronously answer ``query`` using the mode in ``param.mode``.

        Modes: ``"local"``, ``"global"`` and ``"hybrid"`` query the knowledge
        graph plus entity/relationship vector stores; ``"naive"`` uses only
        the chunk vector store. The LLM cache callback runs afterwards even if
        the query raises.

        Args:
            query: the question text.
            param: query options; a fresh ``QueryParam()`` is used when omitted
                (a shared default instance would carry mutations across calls).

        Returns:
            Whatever the selected query function returns.

        Raises:
            ValueError: if ``param.mode`` is not one of the modes above.
        """
        if param is None:
            param = QueryParam()

        # The three knowledge-graph modes share one call signature.
        kg_queries = {
            "local": local_query,
            "global": global_query,
            "hybrid": hybrid_query,
        }
        if param.mode != "naive" and param.mode not in kg_queries:
            raise ValueError(f"Unknown mode {param.mode}")

        try:
            if param.mode == "naive":
                response = await naive_query(
                    query,
                    self.chunks_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            else:
                response = await kg_queries[param.mode](
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
        finally:
            # Flush the LLM cache even when the query failed part-way.
            await self._query_done()
        return response

    async def _query_done(self):
        """Run ``index_done_callback`` on the LLM response cache, if enabled."""
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in [self.llm_response_cache]
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        """Synchronously delete an entity and its relationships.

        Args:
            entity_name: name of the entity to delete (normalized internally).
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        """Asynchronously delete an entity from the vector stores and the graph.

        Errors are logged, not raised; storage callbacks run only when every
        deletion step succeeded.

        Args:
            entity_name: name of the entity to delete; it is upper-cased and
                wrapped in double quotes before lookup.
        """
        # Normalize to the stored form: upper case wrapped in double quotes.
        entity_name = f'"{entity_name.upper()}"'

        try:
            # 1. Remove the entity from the entity vector store.
            await self.entities_vdb.delete_entity(entity_name)
            # 2. Remove its relationships from the relationship vector store.
            await self.relationships_vdb.delete_relation(entity_name)
            # 3. Remove the node from the knowledge graph.
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        """Run ``index_done_callback`` concurrently on the stores touched by deletion."""
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
|
2024-11-16 11:29:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|