lightrag.py使用通义灵码加注释。 其他三个文件使用cursor加注释。

This commit is contained in:
many2many 2024-11-16 11:29:02 +08:00
parent c0fa4da53d
commit c8ee7286cb
4 changed files with 1261 additions and 260 deletions

View File

@ -1,3 +1,8 @@
"""
LightRAG - 轻量级检索增强生成系统
该模块实现了一个基于图的文档检索和问答系统支持文档的存储检索和知识图谱构建
"""
import asyncio
import os
from dataclasses import asdict, dataclass, field
@ -5,144 +10,170 @@ from datetime import datetime
from functools import partial
from typing import Type, cast
# 导入LLM相关功能
from .llm import (
gpt_4o_mini_complete,
openai_embedding,
gpt_4o_mini_complete, # GPT模型完成功能
openai_embedding, # OpenAI文本嵌入功能
)
# 导入核心操作功能
from .operate import (
chunking_by_token_size,
extract_entities,
local_query,
global_query,
hybrid_query,
naive_query,
chunking_by_token_size, # 文本分块
extract_entities, # 实体提取
local_query, # 本地查询
global_query, # 全局查询
hybrid_query, # 混合查询
naive_query, # 简单查询
)
# 导入存储实现
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
JsonKVStorage, # JSON键值存储
NanoVectorDBStorage, # 向量数据库存储
NetworkXStorage, # 图数据库存储
)
from .kg.neo4j_impl import Neo4JStorage
# future KG integrations
from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现
# 未来可能的图数据库集成
# from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage
# )
# 导入工具函数
from .utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
convert_response_to_json,
logger,
set_logger,
EmbeddingFunc, # 嵌入函数类型
compute_mdhash_id, # 计算MD5哈希ID
limit_async_func_call, # 限制异步函数调用
convert_response_to_json, # 响应转JSON
logger, # 日志记录器
set_logger, # 设置日志
)
# 导入基类
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
BaseGraphStorage, # 图存储基类
BaseKVStorage, # 键值存储基类
BaseVectorStorage, # 向量存储基类
StorageNameSpace, # 存储命名空间
QueryParam, # 查询参数
)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
获取或创建事件循环
如果当前线程没有事件循环则创建一个新的
返回值:
asyncio.AbstractEventLoop: 事件循环实例
"""
try:
return asyncio.get_event_loop()
except RuntimeError:
logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
@dataclass
class LightRAG:
"""
轻量级检索增强生成(LightRAG)系统的主类
实现了文档的存储检索知识图谱构建和问答功能
"""
# 工作目录配置,用存储所有缓存文件
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
# 知识图谱存储类型默认使用NetworkX实现
kg: str = field(default="NetworkXStorage")
# 日志级别设置
current_log_level = logger.level
log_level: str = field(default=current_log_level)
# text chunking
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
tiktoken_model_name: str = "gpt-4o-mini"
# 文本分块参数配置
chunk_token_size: int = 1200 # 每个文本块的目标token数
chunk_overlap_token_size: int = 100 # 相邻文本块的重叠token数
tiktoken_model_name: str = "gpt-4o-mini" # 用于计算token的模型名称
# entity extraction
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500
# 实体提取参数
entity_extract_max_gleaning: int = 1 # 最大实体提取次数
entity_summary_to_max_tokens: int = 500 # 实体摘要的最大token数
# node embedding
node_embedding_algorithm: str = "node2vec"
# 节点嵌入配置
node_embedding_algorithm: str = "node2vec" # 节点嵌入算法选择
node2vec_params: dict = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
"dimensions": 1536, # 嵌入向量维度
"num_walks": 10, # 每个节点的随机游走次数
"walk_length": 40, # 每次随机游走的长度
"window_size": 2, # 上下文窗口大小
"iterations": 3, # 训练迭代次数
"random_seed": 3, # 随机种子
}
)
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# 文本嵌入配置
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) # 默认使用OpenAI的嵌入模型
embedding_batch_num: int = 32 # 批处理大小
embedding_func_max_async: int = 16 # 最大并发请求数
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
# 语言模型配置
llm_model_func: callable = gpt_4o_mini_complete # 默认使用的语言模型
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 模型名称
llm_model_max_token_size: int = 32768 # 模型最大token限制
llm_model_max_async: int = 16 # 最大并发请求数
llm_model_kwargs: dict = field(default_factory=dict) # 模型额外参数
# storage
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
enable_llm_cache: bool = True
# 存储配置
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage # 键值存储类
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage # 向量存储类
vector_db_storage_cls_kwargs: dict = field(default_factory=dict) # 向量存储额外参数
enable_llm_cache: bool = True # 是否启用语言模型缓存
# extension
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
# 扩展配置
addon_params: dict = field(default_factory=dict) # 附加参数
convert_response_to_json_func: callable = convert_response_to_json # JSON转换函数
def __post_init__(self):
"""
初始化方法在对象创建后自动调用
负责设置日志初始化存储系统和配置各种功能组件
"""
# 配置日志系统
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
# 记录初始化参数
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# @TODO: should move all storage setup here to leverage initial start params attached to self.
# 根据配置选择图存储实现类
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.kg
]
# 确保工作目录存在
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# 初始化文档存储系统
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs", global_config=asdict(self)
)
# 初始化文本块存储系统
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks", global_config=asdict(self)
)
# 初始化语言模型响应缓存(如果启用)
self.llm_response_cache = (
self.key_string_value_json_storage_cls(
namespace="llm_response_cache", global_config=asdict(self)
@ -150,32 +181,40 @@ class LightRAG:
if self.enable_llm_cache
else None
)
# 初始化实体关系图存储
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation", global_config=asdict(self)
)
# 配置嵌入函数的并发限制
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# 初始化向量数据库存储系统
# 用于存储实体的向量表示
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
# 用于存储关系的向量表示
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
# 用于存储文本块的向量表示
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
# 配置语言模型函数的并发限制和缓存
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
@ -185,33 +224,62 @@ class LightRAG:
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
"""
获取图存储类的实现
根据配置选择合适的图存储后端Neo4J或NetworkX
返回值:
Type[BaseGraphStorage]: 图存储类
"""
return {
"Neo4JStorage": Neo4JStorage,
"NetworkXStorage": NetworkXStorage,
}
def insert(self, string_or_strings):
"""
同步方式插入文档
将字符串或字符串列表插入到系统中进行处理
参数:
string_or_strings: 单个字符串或字符串列表表示要处理的文档内容
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))
async def ainsert(self, string_or_strings):
"""
异步方式插入文档
处理文档内容包括分块实体提取和向量化存储
参数:
string_or_strings: 单个字符串或字符串列表表示要处理的文档内容
"""
try:
# 确保输入是列表形式
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 为每个文档生成唯一ID并创建文档字典
new_docs = {
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
}
# 过滤掉已存在的文档
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("All docs are already in the storage")
return
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
# 处理文档分块
inserting_chunks = {}
for doc_key, doc in new_docs.items():
# 对每个文档进行分块并为每个块生成唯一ID
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
@ -225,19 +293,25 @@ class LightRAG:
)
}
inserting_chunks.update(chunks)
# 过滤掉已存在的文本块
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
# 将新的文本块插入向量数据库
await self.chunks_vdb.upsert(inserting_chunks)
# 提取实体和关系并更新知识图谱
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
@ -246,38 +320,70 @@ class LightRAG:
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities and relationships found")
return
self.chunk_entity_relation_graph = maybe_new_kg
# 更新文档和文本块存储
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
# 完成插入后执行清理工作
await self._insert_done()
async def _insert_done(self):
"""
插入操作完成后的回调函数
负责更新所有存储实例的索引并保存状态
"""
tasks = []
# 遍历所有需要执行回调的存储实例
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
self.full_docs, # 完整文档存储
self.text_chunks, # 文本块存储
self.llm_response_cache, # LLM响应缓存
self.entities_vdb, # 实体向量数据库
self.relationships_vdb, # 关系向量数据库
self.chunks_vdb, # 文本块向量<E59091><E9878F><EFBFBD>据库
self.chunk_entity_relation_graph, # 实体关系图
]:
if storage_inst is None:
continue
# 将每个存储实例的回调任务添加到任务列表
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
# 并发执行所有回调任务
await asyncio.gather(*tasks)
def query(self, query: str, param: QueryParam = QueryParam()):
"""
同步方式执行查询
参数:
query: 查询文本
param: 查询参数配置
返回值:
查询结果
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))
async def aquery(self, query: str, param: QueryParam = QueryParam()):
"""
异步方式执行查询
支持多种查询模式本地查询全局查询混合查询和简单查询
参数:
query: 查询文本
param: 查询参数配置
返回值:
查询结果
"""
# 根据查询模式选择相应的查询方法
if param.mode == "local":
# 本地查询:主要基于局部上下文
response = await local_query(
query,
self.chunk_entity_relation_graph,
@ -288,6 +394,7 @@ class LightRAG:
asdict(self),
)
elif param.mode == "global":
# 全局查询:考虑整个知识图谱
response = await global_query(
query,
self.chunk_entity_relation_graph,
@ -298,6 +405,7 @@ class LightRAG:
asdict(self),
)
elif param.mode == "hybrid":
# 混合查询:结合局部和全局信息
response = await hybrid_query(
query,
self.chunk_entity_relation_graph,
@ -308,6 +416,7 @@ class LightRAG:
asdict(self),
)
elif param.mode == "naive":
# 简单查询:直接基于文本相似度
response = await naive_query(
query,
self.chunks_vdb,
@ -317,38 +426,73 @@ class LightRAG:
)
else:
raise ValueError(f"Unknown mode {param.mode}")
# 执行查询完成后的清理工作
await self._query_done()
return response
async def _query_done(self):
"""
查询操作完成后的回调函数
主要用于更新LLM响应缓存的状态
"""
tasks = []
# 目前只需要处理LLM响应缓存的回调
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
# 并发执行所有回调任务
await asyncio.gather(*tasks)
def delete_by_entity(self, entity_name: str):
"""
同步方式删除指定实体
参数:
entity_name: 要删除的实体名称
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name))
async def adelete_by_entity(self, entity_name: str):
"""
异步方式删除指定实体及其相关的所有信息
参数:
entity_name: 要删除的实体名称
"""
# 标准化实体名称(转为大写并添加引号)
entity_name = f'"{entity_name.upper()}"'
try:
# 依次删除实体在各个存储中的数据:
# 1. 从实体向量数据库中删除
await self.entities_vdb.delete_entity(entity_name)
# 2. 从关系向量数据库中删除相关关系
await self.relationships_vdb.delete_relation(entity_name)
# 3. 从知识图谱中删除节点
await self.chunk_entity_relation_graph.delete_node(entity_name)
# 记录删除成功的日志
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
# 执行删除完成后的清理工作
await self._delete_by_entity_done()
except Exception as e:
# 记录删除过程中的错误
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
"""
实体删除操作完成后的回调函数
负责更新所有相关存储实例的状态
"""
tasks = []
# 遍历需要更新的存储实例:
# - 实体向量数据库
# - 关系向量数据库
# - 实体关系图
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
@ -356,5 +500,11 @@ class LightRAG:
]:
if storage_inst is None:
continue
# 将每个存储实例的回调添加到任务列表
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
# 并发执行所有回调任务
await asyncio.gather(*tasks)

View File

@ -31,9 +31,10 @@ from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
# 禁用并行化以避免tokenizers的并行化导致的问题
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# 使用retry装饰器处理重试逻辑处理OpenAI API的速率限制、连接和超时错误
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@ -48,35 +49,64 @@ async def openai_complete_if_cache(
api_key=None,
**kwargs,
) -> str:
"""
异步函数通过OpenAI的API获取语言模型的补全结果支持缓存机制
参数:
- model: 使用的模型名称
- prompt: 用户输入的提示
- system_prompt: 系统提示可选
- history_messages: 历史消息可选
- base_url: API的基础URL可选
- api_key: API密钥可选
- **kwargs: 其他参数
返回:
- str: 模型生成的文本
"""
# 设置环境变量中的API密钥
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
# 初始化OpenAI异步客户端
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
# 初始化哈希存储和消息列表
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
# 添加系统提示到消息列表
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# 将历史消息和当前提示添加到消息列表
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 检查缓存中是否有结果
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# 调用OpenAI API获取补全结果
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
# 将结果缓存
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
# 返回生成的文本
return response.choices[0].message.content
# 与openai_complete_if_cache类似的函数但用于Azure OpenAI服务
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@ -91,45 +121,71 @@ async def azure_openai_complete_if_cache(
api_key=None,
**kwargs,
):
"""
异步函数通过Azure OpenAI的API获取语言模型的补全结果支持缓存机制
参数:
- model: 使用的模型名称
- prompt: 用户输入的提示
- system_prompt: 系统提示可选
- history_messages: 历史消息可选
- base_url: API的基础URL可选
- api_key: API密钥可选
- **kwargs: 其他参数
返回:
- str: 模型生成的文本
"""
# 设置环境变量中的API密钥和端点
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
# 初始化Azure OpenAI异步客户端
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
# 初始化哈希存储和消息列表
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
# 添加系统提示到消息列表
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# 将历史消息和当前提示添加到消息列表
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
# 检查缓存中是否有结果
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# 调用Azure OpenAI API获取补全结果
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
# 将结果缓存
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
# 返回生成的文本
return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
"""Amazon Bedrock 相关问题的通用错误"""
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
@ -145,6 +201,25 @@ async def bedrock_complete_if_cache(
aws_session_token=None,
**kwargs,
) -> str:
"""
异步使用 Amazon Bedrock 完成文本生成支持缓存
如果缓存命中则直接返回缓存结果该函数在失败时支持重试
参数:
- model: 要使用的 Bedrock 模型的模型 ID
- prompt: 用户输入的提示
- system_prompt: 系统提示如果有
- history_messages: 会话历史消息列表用于对话上下文
- aws_access_key_id: AWS 访问密钥 ID
- aws_secret_access_key: AWS 秘密访问密钥
- aws_session_token: AWS 会话令牌
- **kwargs: 其他参数例如推理参数
返回:
- str: 生成的文本结果
"""
# 设置 AWS 凭证
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
@ -155,24 +230,24 @@ async def bedrock_complete_if_cache(
"AWS_SESSION_TOKEN", aws_session_token
)
# Fix message history format
# 修复消息历史记录格式
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
# 添加用户提示
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
# 初始化 Converse API 参数
args = {"modelId": model, "messages": messages}
# Define system prompt
# 定义系统提示
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
# 映射并设置推理参数
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
@ -187,6 +262,7 @@ async def bedrock_complete_if_cache(
kwargs.pop(param)
)
# 处理缓存逻辑
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
@ -194,7 +270,7 @@ async def bedrock_complete_if_cache(
if if_cache_return is not None:
return if_cache_return["return"]
# Call model via Converse API
# 通过 Converse API 调用模型
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
@ -202,6 +278,7 @@ async def bedrock_complete_if_cache(
except Exception as e:
raise BedrockError(e)
# 更新缓存(如果启用)
if hashing_kv is not None:
await hashing_kv.upsert(
{
@ -214,9 +291,20 @@ async def bedrock_complete_if_cache(
return response["output"]["message"]["content"][0]["text"]
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
"""
初始化Hugging Face模型和tokenizer
使用指定的模型名称初始化模型和tokenizer并根据需要设置padding token
参数:
- model_name: 模型的名称
返回:
- hf_model: 初始化的Hugging Face模型
- hf_tokenizer: 初始化的Hugging Face tokenizer
"""
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
@ -232,6 +320,21 @@ def initialize_hf_model(model_name):
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用缓存的Hugging Face模型进行推理
如果缓存中存在相同的输入则直接返回结果否则使用指定的模型进行推理并将结果缓存
参数:
- model: 模型的名称
- prompt: 用户的输入提示
- system_prompt: 系统的提示可选
- history_messages: 历史消息列表可选
- **kwargs: 其他关键字参数例如hashing_kv用于缓存存储
返回:
- response_text: 模型的响应文本
"""
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
@ -297,32 +400,58 @@ async def hf_model_if_cache(
async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
异步函数通过Olama模型生成回答支持缓存机制以优化性能
参数:
model: 使用的模型名称
prompt: 用户的提问
system_prompt: 系统的提示用于设定对话背景
history_messages: 历史对话消息用于维持对话上下文
**kwargs: 其他参数包括max_tokens, response_format, host, timeout等
返回:
生成的模型回答
"""
# 移除不需要的参数以符合Olama客户端的期望
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
# 初始化Olama异步客户端
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
# 构建消息列表,首先添加系统提示(如果有)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# 获取哈希存储实例,用于缓存
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# 将历史消息和当前用户提问添加到消息列表
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 如果提供了哈希存储,尝试从缓存中获取回答
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# 如果缓存中没有回答调用Olama模型生成回答
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
# 提取生成的回答内容
result = response["message"]["content"]
# 如果使用了哈希存储,将新生成的回答存入缓存
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
# 返回生成的回答
return result
@ -335,8 +464,24 @@ def initialize_lmdeploy_pipeline(
model_format="hf",
quant_policy=0,
):
"""
初始化lmdeploy管道用于模型推理带有缓存机制
参数:
model: 模型路径
tp: 张量并行度
chat_template: 聊天模板配置
log_level: 日志级别
model_format: 模型格式
quant_policy: 量化策略
返回:
初始化的lmdeploy管道实例
"""
# 导入必要的模块和类
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
# 创建并配置lmdeploy管道
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
@ -347,6 +492,7 @@ def initialize_lmdeploy_pipeline(
else None,
log_level="WARNING",
)
# 返回配置好的管道实例
return lmdeploy_pipe
@ -361,39 +507,38 @@ async def lmdeploy_model_if_cache(
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
异步执行语言模型推理支持缓存
该函数初始化 lmdeploy 管道进行模型推理支持多种模型格式和量化策略它处理输入的提示文本系统提示和历史消息
并尝试从缓存中检索响应如果未命中缓存则生成响应并缓存结果以供将来使用
参数:
model (str): 模型路径
可以是以下选项之一
- i) 通过 `lmdeploy convert` 命令转换或从 ii) iii) 下载的本地 turbomind 模型目录路径
- ii) huggingface.co 上托管的 lmdeploy 量化模型的 model_id例如
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"lmdeploy/llama2-chat-70b-4bit"
- iii) huggingface.co 上托管的模型的 model_id例如
"internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ",
"baichuan-inc/Baichuan2-7B-Chat"
chat_template (str): 当模型是 huggingface.co 上的 PyTorch 模型时需要例如 "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" 以及当本地路径的模型名称与 HF 中的原始模型名称不匹配时
tp (int): 张量并行度
prompt (Union[str, List[str]]): 要完成的输入文本
do_preprocess (bool): 是否预处理消息默认为 True表示将应用 chat_template
skip_special_tokens (bool): 解码时是否移除特殊标记默认为 True
do_sample (bool): 是否使用采样否则使用贪心解码默认为 False表示使用贪心解码
"""
# 导入 lmdeploy 及相关模块,如果未安装则抛出错误
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
raise ImportError("请在初始化 lmdeploy 后端之前安装 lmdeploy。")
# 提取并处理关键字参数
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
@ -402,16 +547,18 @@ async def lmdeploy_model_if_cache(
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
# 检查 lmdeploy 版本兼容性,确保支持 do_sample 参数
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
"`do_sample` 参数在 lmdeploy v0.6.0 之前不受支持,当前使用的 lmdeploy 版本为 {}"
.format(lmdeploy.__version__)
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
# 初始化 lmdeploy 管道
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
@ -421,25 +568,31 @@ async def lmdeploy_model_if_cache(
log_level="WARNING",
)
# 构建消息列表
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# 获取哈希存储对象
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 尝试从缓存中获取响应
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# 配置生成参数
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
# 生成响应
response = ""
async for res in lmdeploy_pipe.generate(
messages,
@ -450,14 +603,29 @@ async def lmdeploy_model_if_cache(
):
response += res.response
# 缓存生成的响应
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
return response
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用GPT-4o模型完成文本生成任务
参数:
- prompt: 用户输入的提示文本
- system_prompt: 系统级别的提示文本用于指导模型生成
- history_messages: 历史对话消息用于上下文理解
- **kwargs: 其他可变关键字参数
返回:
- 生成的文本结果
"""
return await openai_complete_if_cache(
"gpt-4o",
prompt,
@ -466,10 +634,21 @@ async def gpt_4o_complete(
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用较小的GPT-4o模型完成文本生成任务
参数:
- prompt: 用户输入的提示文本
- system_prompt: 系统级别的提示文本用于指导模型生成
- history_messages: 历史对话消息用于上下文理解
- **kwargs: 其他可变关键字参数
返回:
- 生成的文本结果
"""
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
@ -478,10 +657,21 @@ async def gpt_4o_mini_complete(
**kwargs,
)
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用Azure上的OpenAI模型完成文本生成任务
参数:
- prompt: 用户输入的提示文本
- system_prompt: 系统级别的提示文本用于指导模型生成
- history_messages: 历史对话消息用于上下文理解
- **kwargs: 其他可变关键字参数
返回:
- 生成的文本结果
"""
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
prompt,
@ -490,10 +680,21 @@ async def azure_openai_complete(
**kwargs,
)
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用Bedrock平台的特定模型完成文本生成任务
参数:
- prompt: 用户输入的提示文本
- system_prompt: 系统级别的提示文本用于指导模型生成
- history_messages: 历史对话消息用于上下文理解
- **kwargs: 其他可变关键字参数
返回:
- 生成的文本结果
"""
return await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
@ -502,10 +703,21 @@ async def bedrock_complete(
**kwargs,
)
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用Hugging Face模型完成文本生成任务
参数:
- prompt: 用户输入的提示文本
- system_prompt: 系统级别的提示文本用于指导模型生成
- history_messages: 历史对话消息用于上下文理解
- **kwargs: 其他可变关键字参数包括模型名称
返回:
- 生成的文本结果
"""
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
model_name,
@ -515,10 +727,21 @@ async def hf_model_complete(
**kwargs,
)
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用Ollama模型完成文本生成任务
参数:
- prompt: 用户输入的提示文本
- system_prompt: 系统级别的提示文本用于指导模型生成
- history_messages: 历史对话消息用于上下文理解
- **kwargs: 其他可变关键字参数包括模型名称
返回:
- 生成的文本结果
"""
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
@ -529,7 +752,9 @@ async def ollama_model_complete(
)
# 使用装饰器添加属性,如嵌入维度和最大令牌大小
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
# 使用重试机制处理可能的速率限制、API连接和超时错误
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@ -541,6 +766,18 @@ async def openai_embedding(
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
"""
使用OpenAI模型生成文本嵌入
参数:
- texts: 需要生成嵌入的文本列表
- model: 使用的模型名称
- base_url: API的基础URL
- api_key: API密钥
返回:
- 嵌入的NumPy数组
"""
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
@ -552,8 +789,9 @@ async def openai_embedding(
)
return np.array([dp.embedding for dp in response.data])
# 使用装饰器添加属性,如嵌入维度和最大令牌大小
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
# 使用重试机制处理可能的速率限制、API连接和超时错误
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@ -565,6 +803,18 @@ async def azure_openai_embedding(
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
"""
使用Azure OpenAI模型生成文本嵌入
参数:
- texts: 需要生成嵌入的文本列表
- model: 使用的模型名称
- base_url: API的基础URL
- api_key: API密钥
返回:
- 嵌入的NumPy数组
"""
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
@ -581,7 +831,7 @@ async def azure_openai_embedding(
)
return np.array([dp.embedding for dp in response.data])
# 使用重试机制处理可能的速率限制、API连接和超时错误
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@ -594,6 +844,19 @@ async def siliconcloud_embedding(
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
"""
使用SiliconCloud模型生成文本嵌入
参数:
- texts: 需要生成嵌入的文本列表
- model: 使用的模型名称
- base_url: API的基础URL
- max_token_size: 最大令牌大小
- api_key: API密钥
返回:
- 嵌入的NumPy数组
"""
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
@ -633,6 +896,22 @@ async def bedrock_embedding(
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
"""
生成给定文本的嵌入向量
使用指定的模型对文本列表进行嵌入处理支持Amazon Bedrock和Cohere模型
参数:
- texts: 需要嵌入的文本列表
- model: 使用的模型标识符默认为"amazon.titan-embed-text-v2:0"
- aws_access_key_id: AWS访问密钥ID
- aws_secret_access_key: AWS秘密访问密钥
- aws_session_token: AWS会话令牌
返回:
- 嵌入向量的NumPy数组
"""
# 设置AWS环境变量
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
@ -643,11 +922,14 @@ async def bedrock_embedding(
"AWS_SESSION_TOKEN", aws_session_token
)
# 创建aioboto3会话
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
# 根据模型提供者进行不同的处理
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
# 根据模型版本构建请求体
if "v2" in model:
body = json.dumps(
{
@ -661,6 +943,7 @@ async def bedrock_embedding(
else:
raise ValueError(f"Model {model} is not supported!")
# 调用Bedrock模型
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
@ -693,9 +976,22 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
"""
使用Hugging Face模型生成给定文本的嵌入向量
参数:
- texts: 需要嵌入的文本列表
- tokenizer: Hugging Face的标记器实例
- embed_model: Hugging Face的嵌入模型实例
返回:
- 嵌入向量的NumPy数组
"""
# 对文本进行标记化处理
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids
# 使用模型生成嵌入向量
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
@ -703,9 +999,22 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
"""
使用Ollama模型生成给定文本的嵌入向量
参数:
- texts: 需要嵌入的文本列表
- embed_model: 使用的嵌入模型标识符
- **kwargs: 传递给Ollama客户端的其他参数
返回:
- 嵌入向量的列表
"""
embed_text = []
# 创建Ollama客户端实例
ollama_client = ollama.Client(**kwargs)
for text in texts:
# 调用模型生成嵌入向量
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
@ -798,4 +1107,4 @@ if __name__ == "__main__":
result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())
asyncio.run(main())

File diff suppressed because it is too large Load Diff

View File

@ -20,27 +20,63 @@ from .base import (
BaseVectorStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
"""
基于JSON文件的键值存储实现类
继承自BaseKVStorage提供基本的键值存储功能
数据以JSON格式保存在文件系统中
"""
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data = load_json(self._file_name) or {}
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
"""
初始化方法在对象创建后自动调用
- 设置工作目录和文件路径
- 加载已存在的JSON数据
"""
working_dir = self.global_config["working_dir"] # 从全局配置获取工作目录
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") # 构建JSON文件完整路径
self._data = load_json(self._file_name) or {} # 加载JSON文件如果不存在则初始化为空字典
logger.info(f"Load KV {self.namespace} with {len(self._data)} data") # 记录加载数据的数量
async def all_keys(self) -> list[str]:
"""
获取存储中的所有键
返回值
list[str]: 包含所有键的列表
"""
return list(self._data.keys())
async def index_done_callback(self):
"""
索引完成后的回调函数
将当前内存中的数据写入JSON文件
"""
write_json(self._data, self._file_name)
async def get_by_id(self, id):
"""
通过ID获取单个数据
参数
id: 要查询的数据ID
返回值
查找到的数据如果不存在则返回None
"""
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
"""
批量获取多个ID的数据
参数
ids: ID列表
fields: 可选指定要返回的字段列表
返回值
list: 包含查询结果的列表每个元素对应一个ID的数据
"""
if fields is None:
# 如果未指定字段,返回完整数据
return [self._data.get(id, None) for id in ids]
# 如果指定了字段,只返回指定的字段
return [
(
{k: v for k, v in self._data[id].items() if k in fields}
@ -51,38 +87,80 @@ class JsonKVStorage(BaseKVStorage):
]
async def filter_keys(self, data: list[str]) -> set[str]:
"""
过滤出不存在于存储中的键
参数
data: 要检查的键列表
返回值
set[str]: 不存在的键集合
"""
return set([s for s in data if s not in self._data])
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
return left_data
"""
更新或插入数据
参数
data: 要更新/插入的数据字典格式为 {id: {字段: }}
返回值
dict: 实际插入的新数据不包含更新的数据
"""
left_data = {k: v for k, v in data.items() if k not in self._data} # 筛选出新数据
self._data.update(left_data) # 更新存储
return left_data # 返回新插入的数据
async def drop(self):
"""
清空所有数据
将内存中的数据字典重置为空
"""
self._data = {}
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
"""
向量数据库存储实现类
基于NanoVectorDB实现向量存储和检索功能
支持向量的增删改查操作
"""
# 余弦相似度阈值,用于过滤搜索结果
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
"""
初始化方法在对象创建后自动调用
设置存储文件路径批处理大小并初始化向量数据库客户端
"""
# 构建向量数据库存储文件路径
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
# 设置批处理大小
self._max_batch_size = self.global_config["embedding_batch_num"]
# 初始化向量数据库客户端
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
# 从配置中获取相似度阈值
self.cosine_better_than_threshold = self.global_config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]):
"""
更新或插入向量数据
参数:
data: 包含向量数据的字典格式为 {id: {字段: }}
返回值:
list: 插入结果
"""
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
# 准备数据,提取元数据字段
list_data = [
{
"__id__": k,
@ -90,28 +168,49 @@ class NanoVectorDBStorage(BaseVectorStorage):
}
for k, v in data.items()
]
# 提取内容并分批处理
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
# 并行计算向量嵌入
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
# 将向量添加到数据中
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# 执行更新/插入操作
results = self._client.upsert(datas=list_data)
return results
async def query(self, query: str, top_k=5):
"""
查询最相似的向量
参数:
query: 查询文本
top_k: 返回的最相似结果数量
返回值:
list: 包含相似度结果的列表
"""
# 计算查询文本的向量表示
embedding = await self.embedding_func([query])
embedding = embedding[0]
# 执行向量检索
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
# 格式化返回结果
results = [
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
@ -119,12 +218,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
@property
def client_storage(self):
"""获取底层存储对象"""
return getattr(self._client, "_NanoVectorDB__storage")
async def delete_entity(self, entity_name: str):
"""
删除指定实体
参数:
entity_name: 要删除的实体名称
"""
try:
# 计算实体ID
entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
# 检查并删除实体
if self._client.get(entity_id):
self._client.delete(entity_id)
logger.info(f"Entity {entity_name} have been deleted.")
@ -134,7 +241,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error while deleting entity {entity_name}: {e}")
async def delete_relation(self, entity_name: str):
"""
删除与指定实体相关的所有关系
参数:
entity_name: 实体名称
"""
try:
# 查找所有相关关系
relations = [
dp
for dp in self.client_storage["data"]
@ -142,6 +255,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
]
ids_to_delete = [relation["__id__"] for relation in relations]
# 执行删除操作
if ids_to_delete:
self._client.delete(ids_to_delete)
logger.info(
@ -155,19 +269,38 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
async def index_done_callback(self):
"""索引完成后的回调函数,保存数据到存储文件"""
self._client.save()
@dataclass
class NetworkXStorage(BaseGraphStorage):
"""
基于NetworkX的图存储实现类
提供图数据的存储读取和操作功能
"""
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
"""
从文件加载图数据
参数:
file_name: GraphML文件路径
返回值:
nx.Graph: 加载的图对象如果文件不存在返回None
"""
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
"""
将图数据写入文件
参数:
graph: 要保存的图对象
file_name: 保存路径
"""
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
@ -175,40 +308,51 @@ class NetworkXStorage(BaseGraphStorage):
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
"""
获取图的最大连通分量并确保节点和边的顺序稳定
参考: https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
参数:
graph: 输入图
返回值:
nx.Graph: 处理后的稳定图
"""
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
# 对节点标签进行标准化处理
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore
}
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
"""
确保无向图的关系始终以相同的方式读取
参数:
graph: 输入图
返回值:
nx.Graph: 稳定化后的图
"""
# 根据图的类型创建新图
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
# 对节点进行排序
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
# 添加排序后的节点
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
# 对于无向图,确保边的源节点和目标节点有固定顺序
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
source, target = target, source
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
@ -216,12 +360,18 @@ class NetworkXStorage(BaseGraphStorage):
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
# 对边进行排序
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
"""
初始化方法
- 设置图存储文件路径
- 加载已存在的图数据
- 初始化节点嵌入算法
"""
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
@ -236,46 +386,56 @@ class NetworkXStorage(BaseGraphStorage):
}
async def index_done_callback(self):
"""索引完成后的回调,保存图数据到文件"""
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
"""检查节点是否存在"""
return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""检查边是否存在"""
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
"""获取节点数据"""
return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
"""获取节点的度"""
return self._graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""获取边的度(源节点度 + 目标节点度)"""
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""获取边的数据"""
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
"""获取节点的所有边"""
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""更新或插入节点"""
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
"""更新或插入边"""
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
删除指定的节点
参数:
node_id: 要删除的节点ID
"""
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
@ -284,12 +444,23 @@ class NetworkXStorage(BaseGraphStorage):
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""
使用指定算法进行节点嵌入
参数:
algorithm: 嵌入算法名称
返回值:
tuple: (嵌入向量数组, 节点ID列表)
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
async def _node2vec_embed(self):
"""
使用node2vec算法进行节点嵌入未使用
返回值:
tuple: (嵌入向量数组, 节点ID列表)
"""
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
@ -298,4 +469,4 @@ class NetworkXStorage(BaseGraphStorage):
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
return embeddings, nodes_ids