lightrag.py使用通义灵码加注释。 其他三个文件使用cursor加注释。
This commit is contained in:
parent
c0fa4da53d
commit
c8ee7286cb
@ -1,3 +1,8 @@
|
||||
"""
|
||||
LightRAG - 轻量级检索增强生成系统
|
||||
该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
@ -5,144 +10,170 @@ from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Type, cast
|
||||
|
||||
# 导入LLM相关功能
|
||||
from .llm import (
|
||||
gpt_4o_mini_complete,
|
||||
openai_embedding,
|
||||
gpt_4o_mini_complete, # GPT模型完成功能
|
||||
openai_embedding, # OpenAI文本嵌入功能
|
||||
)
|
||||
|
||||
# 导入核心操作功能
|
||||
from .operate import (
|
||||
chunking_by_token_size,
|
||||
extract_entities,
|
||||
local_query,
|
||||
global_query,
|
||||
hybrid_query,
|
||||
naive_query,
|
||||
chunking_by_token_size, # 文本分块
|
||||
extract_entities, # 实体提取
|
||||
local_query, # 本地查询
|
||||
global_query, # 全局查询
|
||||
hybrid_query, # 混合查询
|
||||
naive_query, # 简单查询
|
||||
)
|
||||
|
||||
# 导入存储实现
|
||||
from .storage import (
|
||||
JsonKVStorage,
|
||||
NanoVectorDBStorage,
|
||||
NetworkXStorage,
|
||||
JsonKVStorage, # JSON键值存储
|
||||
NanoVectorDBStorage, # 向量数据库存储
|
||||
NetworkXStorage, # 图数据库存储
|
||||
)
|
||||
|
||||
from .kg.neo4j_impl import Neo4JStorage
|
||||
# future KG integrations
|
||||
|
||||
from .kg.neo4j_impl import Neo4JStorage # Neo4j图数据库实现
|
||||
# 未来可能的图数据库集成
|
||||
# from .kg.ArangoDB_impl import (
|
||||
# GraphStorage as ArangoDBStorage
|
||||
# )
|
||||
|
||||
|
||||
# 导入工具函数
|
||||
from .utils import (
|
||||
EmbeddingFunc,
|
||||
compute_mdhash_id,
|
||||
limit_async_func_call,
|
||||
convert_response_to_json,
|
||||
logger,
|
||||
set_logger,
|
||||
EmbeddingFunc, # 嵌入函数类型
|
||||
compute_mdhash_id, # 计算MD5哈希ID
|
||||
limit_async_func_call, # 限制异步函数调用
|
||||
convert_response_to_json, # 响应转JSON
|
||||
logger, # 日志记录器
|
||||
set_logger, # 设置日志
|
||||
)
|
||||
|
||||
# 导入基类
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
StorageNameSpace,
|
||||
QueryParam,
|
||||
BaseGraphStorage, # 图存储基类
|
||||
BaseKVStorage, # 键值存储基类
|
||||
BaseVectorStorage, # 向量存储基类
|
||||
StorageNameSpace, # 存储命名空间
|
||||
QueryParam, # 查询参数
|
||||
)
|
||||
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
"""
|
||||
获取或创建事件循环
|
||||
如果当前线程没有事件循环,则创建一个新的
|
||||
返回值:
|
||||
asyncio.AbstractEventLoop: 事件循环实例
|
||||
"""
|
||||
try:
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
except RuntimeError:
|
||||
logger.info("Creating a new event loop in main thread.")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop
|
||||
|
||||
|
||||
@dataclass
|
||||
class LightRAG:
|
||||
"""
|
||||
轻量级检索增强生成(LightRAG)系统的主类
|
||||
实现了文档的存储、检索、知识图谱构建和问答功能
|
||||
"""
|
||||
|
||||
# 工作目录配置,用存储所有缓存文件
|
||||
working_dir: str = field(
|
||||
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
||||
)
|
||||
|
||||
# 知识图谱存储类型,默认使用NetworkX实现
|
||||
kg: str = field(default="NetworkXStorage")
|
||||
|
||||
# 日志级别设置
|
||||
current_log_level = logger.level
|
||||
log_level: str = field(default=current_log_level)
|
||||
|
||||
# text chunking
|
||||
chunk_token_size: int = 1200
|
||||
chunk_overlap_token_size: int = 100
|
||||
tiktoken_model_name: str = "gpt-4o-mini"
|
||||
# 文本分块参数配置
|
||||
chunk_token_size: int = 1200 # 每个文本块的目标token数
|
||||
chunk_overlap_token_size: int = 100 # 相邻文本块的重叠token数
|
||||
tiktoken_model_name: str = "gpt-4o-mini" # 用于计算token的模型名称
|
||||
|
||||
# entity extraction
|
||||
entity_extract_max_gleaning: int = 1
|
||||
entity_summary_to_max_tokens: int = 500
|
||||
# 实体提取参数
|
||||
entity_extract_max_gleaning: int = 1 # 最大实体提取次数
|
||||
entity_summary_to_max_tokens: int = 500 # 实体摘要的最大token数
|
||||
|
||||
# node embedding
|
||||
node_embedding_algorithm: str = "node2vec"
|
||||
# 节点嵌入配置
|
||||
node_embedding_algorithm: str = "node2vec" # 节点嵌入算法选择
|
||||
node2vec_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"dimensions": 1536,
|
||||
"num_walks": 10,
|
||||
"walk_length": 40,
|
||||
"window_size": 2,
|
||||
"iterations": 3,
|
||||
"random_seed": 3,
|
||||
"dimensions": 1536, # 嵌入向量维度
|
||||
"num_walks": 10, # 每个节点的随机游走次数
|
||||
"walk_length": 40, # 每次随机游走的长度
|
||||
"window_size": 2, # 上下文窗口大小
|
||||
"iterations": 3, # 训练迭代次数
|
||||
"random_seed": 3, # 随机种子
|
||||
}
|
||||
)
|
||||
|
||||
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
|
||||
embedding_batch_num: int = 32
|
||||
embedding_func_max_async: int = 16
|
||||
# 文本嵌入配置
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) # 默认使用OpenAI的嵌入模型
|
||||
embedding_batch_num: int = 32 # 批处理大小
|
||||
embedding_func_max_async: int = 16 # 最大并发请求数
|
||||
|
||||
# LLM
|
||||
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
||||
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_max_token_size: int = 32768
|
||||
llm_model_max_async: int = 16
|
||||
llm_model_kwargs: dict = field(default_factory=dict)
|
||||
# 语言模型配置
|
||||
llm_model_func: callable = gpt_4o_mini_complete # 默认使用的语言模型
|
||||
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 模型名称
|
||||
llm_model_max_token_size: int = 32768 # 模型最大token限制
|
||||
llm_model_max_async: int = 16 # 最大并发请求数
|
||||
llm_model_kwargs: dict = field(default_factory=dict) # 模型额外参数
|
||||
|
||||
# storage
|
||||
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
||||
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
||||
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
||||
enable_llm_cache: bool = True
|
||||
# 存储配置
|
||||
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage # 键值存储类
|
||||
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage # 向量存储类
|
||||
vector_db_storage_cls_kwargs: dict = field(default_factory=dict) # 向量存储额外参数
|
||||
enable_llm_cache: bool = True # 是否启用语言模型缓存
|
||||
|
||||
# extension
|
||||
addon_params: dict = field(default_factory=dict)
|
||||
convert_response_to_json_func: callable = convert_response_to_json
|
||||
# 扩展配置
|
||||
addon_params: dict = field(default_factory=dict) # 附加参数
|
||||
convert_response_to_json_func: callable = convert_response_to_json # JSON转换函数
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
初始化方法,在对象创建后自动调用
|
||||
负责设置日志、初始化存储系统和配置各种功能组件
|
||||
"""
|
||||
# 配置日志系统
|
||||
log_file = os.path.join(self.working_dir, "lightrag.log")
|
||||
set_logger(log_file)
|
||||
logger.setLevel(self.log_level)
|
||||
|
||||
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
||||
|
||||
# 记录初始化参数
|
||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
# @TODO: should move all storage setup here to leverage initial start params attached to self.
|
||||
# 根据配置选择图存储实现类
|
||||
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
|
||||
self.kg
|
||||
]
|
||||
|
||||
# 确保工作目录存在
|
||||
if not os.path.exists(self.working_dir):
|
||||
logger.info(f"Creating working directory {self.working_dir}")
|
||||
os.makedirs(self.working_dir)
|
||||
|
||||
# 初始化文档存储系统
|
||||
self.full_docs = self.key_string_value_json_storage_cls(
|
||||
namespace="full_docs", global_config=asdict(self)
|
||||
)
|
||||
|
||||
# 初始化文本块存储系统
|
||||
self.text_chunks = self.key_string_value_json_storage_cls(
|
||||
namespace="text_chunks", global_config=asdict(self)
|
||||
)
|
||||
|
||||
# 初始化语言模型响应缓存(如果启用)
|
||||
self.llm_response_cache = (
|
||||
self.key_string_value_json_storage_cls(
|
||||
namespace="llm_response_cache", global_config=asdict(self)
|
||||
@ -150,32 +181,40 @@ class LightRAG:
|
||||
if self.enable_llm_cache
|
||||
else None
|
||||
)
|
||||
|
||||
# 初始化实体关系图存储
|
||||
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
||||
namespace="chunk_entity_relation", global_config=asdict(self)
|
||||
)
|
||||
|
||||
# 配置嵌入函数的并发限制
|
||||
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
|
||||
self.embedding_func
|
||||
)
|
||||
|
||||
# 初始化向量数据库存储系统
|
||||
# 用于存储实体的向量表示
|
||||
self.entities_vdb = self.vector_db_storage_cls(
|
||||
namespace="entities",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"entity_name"},
|
||||
)
|
||||
# 用于存储关系的向量表示
|
||||
self.relationships_vdb = self.vector_db_storage_cls(
|
||||
namespace="relationships",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"src_id", "tgt_id"},
|
||||
)
|
||||
# 用于存储文本块的向量表示
|
||||
self.chunks_vdb = self.vector_db_storage_cls(
|
||||
namespace="chunks",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
# 配置语言模型函数的并发限制和缓存
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
partial(
|
||||
self.llm_model_func,
|
||||
@ -185,33 +224,62 @@ class LightRAG:
|
||||
)
|
||||
|
||||
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
||||
"""
|
||||
获取图存储类的实现
|
||||
根据配置选择合适的图存储后端(Neo4J或NetworkX)
|
||||
|
||||
返回值:
|
||||
Type[BaseGraphStorage]: 图存储类
|
||||
"""
|
||||
return {
|
||||
"Neo4JStorage": Neo4JStorage,
|
||||
"NetworkXStorage": NetworkXStorage,
|
||||
}
|
||||
|
||||
def insert(self, string_or_strings):
|
||||
"""
|
||||
同步方式插入文档
|
||||
将字符串或字符串列表插入到系统中进行处理
|
||||
|
||||
参数:
|
||||
string_or_strings: 单个字符串或字符串列表,表示要处理的文档内容
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.ainsert(string_or_strings))
|
||||
|
||||
async def ainsert(self, string_or_strings):
|
||||
"""
|
||||
异步方式插入文档
|
||||
处理文档内容,包括分块、实体提取和向量化存储
|
||||
|
||||
参数:
|
||||
string_or_strings: 单个字符串或字符串列表,表示要处理的文档内容
|
||||
"""
|
||||
try:
|
||||
# 确保输入是列表形式
|
||||
if isinstance(string_or_strings, str):
|
||||
string_or_strings = [string_or_strings]
|
||||
|
||||
# 为每个文档生成唯一ID并创建文档字典
|
||||
new_docs = {
|
||||
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
|
||||
for c in string_or_strings
|
||||
}
|
||||
|
||||
# 过滤掉已存在的文档
|
||||
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
||||
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
||||
|
||||
if not len(new_docs):
|
||||
logger.warning("All docs are already in the storage")
|
||||
return
|
||||
|
||||
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
||||
|
||||
# 处理文档分块
|
||||
inserting_chunks = {}
|
||||
for doc_key, doc in new_docs.items():
|
||||
# 对每个文档进行分块,并为每个块生成唯一ID
|
||||
chunks = {
|
||||
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
||||
**dp,
|
||||
@ -225,19 +293,25 @@ class LightRAG:
|
||||
)
|
||||
}
|
||||
inserting_chunks.update(chunks)
|
||||
|
||||
# 过滤掉已存在的文本块
|
||||
_add_chunk_keys = await self.text_chunks.filter_keys(
|
||||
list(inserting_chunks.keys())
|
||||
)
|
||||
inserting_chunks = {
|
||||
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
||||
}
|
||||
|
||||
if not len(inserting_chunks):
|
||||
logger.warning("All chunks are already in the storage")
|
||||
return
|
||||
|
||||
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
||||
|
||||
# 将新的文本块插入向量数据库
|
||||
await self.chunks_vdb.upsert(inserting_chunks)
|
||||
|
||||
# 提取实体和关系并更新知识图谱
|
||||
logger.info("[Entity Extraction]...")
|
||||
maybe_new_kg = await extract_entities(
|
||||
inserting_chunks,
|
||||
@ -246,38 +320,70 @@ class LightRAG:
|
||||
relationships_vdb=self.relationships_vdb,
|
||||
global_config=asdict(self),
|
||||
)
|
||||
|
||||
if maybe_new_kg is None:
|
||||
logger.warning("No new entities and relationships found")
|
||||
return
|
||||
|
||||
self.chunk_entity_relation_graph = maybe_new_kg
|
||||
|
||||
# 更新文档和文本块存储
|
||||
await self.full_docs.upsert(new_docs)
|
||||
await self.text_chunks.upsert(inserting_chunks)
|
||||
finally:
|
||||
# 完成插入后执行清理工作
|
||||
await self._insert_done()
|
||||
|
||||
async def _insert_done(self):
|
||||
"""
|
||||
插入操作完成后的回调函数
|
||||
负责更新所有存储实例的索引并保存状态
|
||||
"""
|
||||
tasks = []
|
||||
# 遍历所有需要执行回调的存储实例
|
||||
for storage_inst in [
|
||||
self.full_docs,
|
||||
self.text_chunks,
|
||||
self.llm_response_cache,
|
||||
self.entities_vdb,
|
||||
self.relationships_vdb,
|
||||
self.chunks_vdb,
|
||||
self.chunk_entity_relation_graph,
|
||||
self.full_docs, # 完整文档存储
|
||||
self.text_chunks, # 文本块存储
|
||||
self.llm_response_cache, # LLM响应缓存
|
||||
self.entities_vdb, # 实体向量数据库
|
||||
self.relationships_vdb, # 关系向量数据库
|
||||
self.chunks_vdb, # 文本块向量<E59091><E9878F><EFBFBD>据库
|
||||
self.chunk_entity_relation_graph, # 实体关系图
|
||||
]:
|
||||
if storage_inst is None:
|
||||
continue
|
||||
# 将每个存储实例的回调任务添加到任务列表
|
||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||
# 并发执行所有回调任务
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def query(self, query: str, param: QueryParam = QueryParam()):
|
||||
"""
|
||||
同步方式执行查询
|
||||
|
||||
参数:
|
||||
query: 查询文本
|
||||
param: 查询参数配置
|
||||
返回值:
|
||||
查询结果
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.aquery(query, param))
|
||||
|
||||
async def aquery(self, query: str, param: QueryParam = QueryParam()):
|
||||
"""
|
||||
异步方式执行查询
|
||||
支持多种查询模式:本地查询、全局查询、混合查询和简单查询
|
||||
|
||||
参数:
|
||||
query: 查询文本
|
||||
param: 查询参数配置
|
||||
返回值:
|
||||
查询结果
|
||||
"""
|
||||
# 根据查询模式选择相应的查询方法
|
||||
if param.mode == "local":
|
||||
# 本地查询:主要基于局部上下文
|
||||
response = await local_query(
|
||||
query,
|
||||
self.chunk_entity_relation_graph,
|
||||
@ -288,6 +394,7 @@ class LightRAG:
|
||||
asdict(self),
|
||||
)
|
||||
elif param.mode == "global":
|
||||
# 全局查询:考虑整个知识图谱
|
||||
response = await global_query(
|
||||
query,
|
||||
self.chunk_entity_relation_graph,
|
||||
@ -298,6 +405,7 @@ class LightRAG:
|
||||
asdict(self),
|
||||
)
|
||||
elif param.mode == "hybrid":
|
||||
# 混合查询:结合局部和全局信息
|
||||
response = await hybrid_query(
|
||||
query,
|
||||
self.chunk_entity_relation_graph,
|
||||
@ -308,6 +416,7 @@ class LightRAG:
|
||||
asdict(self),
|
||||
)
|
||||
elif param.mode == "naive":
|
||||
# 简单查询:直接基于文本相似度
|
||||
response = await naive_query(
|
||||
query,
|
||||
self.chunks_vdb,
|
||||
@ -317,38 +426,73 @@ class LightRAG:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown mode {param.mode}")
|
||||
|
||||
# 执行查询完成后的清理工作
|
||||
await self._query_done()
|
||||
return response
|
||||
|
||||
async def _query_done(self):
|
||||
"""
|
||||
查询操作完成后的回调函数
|
||||
主要用于更新LLM响应缓存的状态
|
||||
"""
|
||||
tasks = []
|
||||
# 目前只需要处理LLM响应缓存的回调
|
||||
for storage_inst in [self.llm_response_cache]:
|
||||
if storage_inst is None:
|
||||
continue
|
||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||
# 并发执行所有回调任务
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def delete_by_entity(self, entity_name: str):
|
||||
"""
|
||||
同步方式删除指定实体
|
||||
|
||||
参数:
|
||||
entity_name: 要删除的实体名称
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.adelete_by_entity(entity_name))
|
||||
|
||||
async def adelete_by_entity(self, entity_name: str):
|
||||
entity_name = f'"{entity_name.upper()}"'
|
||||
"""
|
||||
异步方式删除指定实体及其相关的所有信息
|
||||
|
||||
参数:
|
||||
entity_name: 要删除的实体名称
|
||||
"""
|
||||
# 标准化实体名称(转为大写并添加引号)
|
||||
entity_name = f'"{entity_name.upper()}"'
|
||||
try:
|
||||
# 依次删除实体在各个存储中的数据:
|
||||
# 1. 从实体向量数据库中删除
|
||||
await self.entities_vdb.delete_entity(entity_name)
|
||||
# 2. 从关系向量数据库中删除相关关系
|
||||
await self.relationships_vdb.delete_relation(entity_name)
|
||||
# 3. 从知识图谱中删除节点
|
||||
await self.chunk_entity_relation_graph.delete_node(entity_name)
|
||||
|
||||
# 记录删除成功的日志
|
||||
logger.info(
|
||||
f"Entity '{entity_name}' and its relationships have been deleted."
|
||||
)
|
||||
# 执行删除完成后的清理工作
|
||||
await self._delete_by_entity_done()
|
||||
except Exception as e:
|
||||
# 记录删除过程中的错误
|
||||
logger.error(f"Error while deleting entity '{entity_name}': {e}")
|
||||
|
||||
async def _delete_by_entity_done(self):
|
||||
"""
|
||||
实体删除操作完成后的回调函数
|
||||
负责更新所有相关存储实例的状态
|
||||
"""
|
||||
tasks = []
|
||||
# 遍历需要更新的存储实例:
|
||||
# - 实体向量数据库
|
||||
# - 关系向量数据库
|
||||
# - 实体关系图
|
||||
for storage_inst in [
|
||||
self.entities_vdb,
|
||||
self.relationships_vdb,
|
||||
@ -356,5 +500,11 @@ class LightRAG:
|
||||
]:
|
||||
if storage_inst is None:
|
||||
continue
|
||||
# 将每个存储实例的回调添加到任务列表
|
||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||
# 并发执行所有回调任务
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
|
||||
|
||||
|
403
lightrag/llm.py
403
lightrag/llm.py
@ -31,9 +31,10 @@ from typing import List, Dict, Callable, Any
|
||||
from .base import BaseKVStorage
|
||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
|
||||
# 禁用并行化以避免tokenizers的并行化导致的问题
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
# 使用retry装饰器处理重试逻辑,处理OpenAI API的速率限制、连接和超时错误
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
@ -48,35 +49,64 @@ async def openai_complete_if_cache(
|
||||
api_key=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
异步函数,通过OpenAI的API获取语言模型的补全结果,支持缓存机制。
|
||||
|
||||
参数:
|
||||
- model: 使用的模型名称
|
||||
- prompt: 用户输入的提示
|
||||
- system_prompt: 系统提示(可选)
|
||||
- history_messages: 历史消息(可选)
|
||||
- base_url: API的基础URL(可选)
|
||||
- api_key: API密钥(可选)
|
||||
- **kwargs: 其他参数
|
||||
|
||||
返回:
|
||||
- str: 模型生成的文本
|
||||
"""
|
||||
# 设置环境变量中的API密钥
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
# 初始化OpenAI异步客户端
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
|
||||
# 初始化哈希存储和消息列表
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
|
||||
# 添加系统提示到消息列表
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 将历史消息和当前提示添加到消息列表
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 检查缓存中是否有结果
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
# 调用OpenAI API获取补全结果
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
# 将结果缓存
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
||||
)
|
||||
|
||||
# 返回生成的文本
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
# 与openai_complete_if_cache类似的函数,但用于Azure OpenAI服务
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
@ -91,45 +121,71 @@ async def azure_openai_complete_if_cache(
|
||||
api_key=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
异步函数,通过Azure OpenAI的API获取语言模型的补全结果,支持缓存机制。
|
||||
|
||||
参数:
|
||||
- model: 使用的模型名称
|
||||
- prompt: 用户输入的提示
|
||||
- system_prompt: 系统提示(可选)
|
||||
- history_messages: 历史消息(可选)
|
||||
- base_url: API的基础URL(可选)
|
||||
- api_key: API密钥(可选)
|
||||
- **kwargs: 其他参数
|
||||
|
||||
返回:
|
||||
- str: 模型生成的文本
|
||||
"""
|
||||
# 设置环境变量中的API密钥和端点
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
|
||||
# 初始化Azure OpenAI异步客户端
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
|
||||
# 初始化哈希存储和消息列表
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
|
||||
# 添加系统提示到消息列表
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 将历史消息和当前提示添加到消息列表
|
||||
messages.extend(history_messages)
|
||||
if prompt is not None:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 检查缓存中是否有结果
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
# 调用Azure OpenAI API获取补全结果
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
# 将结果缓存
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
||||
)
|
||||
|
||||
# 返回生成的文本
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Generic error for issues related to Amazon Bedrock"""
|
||||
|
||||
|
||||
"""Amazon Bedrock 相关问题的通用错误"""
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, max=60),
|
||||
@ -145,6 +201,25 @@ async def bedrock_complete_if_cache(
|
||||
aws_session_token=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
异步使用 Amazon Bedrock 完成文本生成,支持缓存。
|
||||
|
||||
如果缓存命中,则直接返回缓存结果。该函数在失败时支持重试。
|
||||
|
||||
参数:
|
||||
- model: 要使用的 Bedrock 模型的模型 ID。
|
||||
- prompt: 用户输入的提示。
|
||||
- system_prompt: 系统提示,如果有。
|
||||
- history_messages: 会话历史消息列表,用于对话上下文。
|
||||
- aws_access_key_id: AWS 访问密钥 ID。
|
||||
- aws_secret_access_key: AWS 秘密访问密钥。
|
||||
- aws_session_token: AWS 会话令牌。
|
||||
- **kwargs: 其他参数,例如推理参数。
|
||||
|
||||
返回:
|
||||
- str: 生成的文本结果。
|
||||
"""
|
||||
# 设置 AWS 凭证
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
@ -155,24 +230,24 @@ async def bedrock_complete_if_cache(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
# Fix message history format
|
||||
# 修复消息历史记录格式
|
||||
messages = []
|
||||
for history_message in history_messages:
|
||||
message = copy.copy(history_message)
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
messages.append(message)
|
||||
|
||||
# Add user prompt
|
||||
# 添加用户提示
|
||||
messages.append({"role": "user", "content": [{"text": prompt}]})
|
||||
|
||||
# Initialize Converse API arguments
|
||||
# 初始化 Converse API 参数
|
||||
args = {"modelId": model, "messages": messages}
|
||||
|
||||
# Define system prompt
|
||||
# 定义系统提示
|
||||
if system_prompt:
|
||||
args["system"] = [{"text": system_prompt}]
|
||||
|
||||
# Map and set up inference parameters
|
||||
# 映射并设置推理参数
|
||||
inference_params_map = {
|
||||
"max_tokens": "maxTokens",
|
||||
"top_p": "topP",
|
||||
@ -187,6 +262,7 @@ async def bedrock_complete_if_cache(
|
||||
kwargs.pop(param)
|
||||
)
|
||||
|
||||
# 处理缓存逻辑
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
@ -194,7 +270,7 @@ async def bedrock_complete_if_cache(
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
# Call model via Converse API
|
||||
# 通过 Converse API 调用模型
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
try:
|
||||
@ -202,6 +278,7 @@ async def bedrock_complete_if_cache(
|
||||
except Exception as e:
|
||||
raise BedrockError(e)
|
||||
|
||||
# 更新缓存(如果启用)
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{
|
||||
@ -214,9 +291,20 @@ async def bedrock_complete_if_cache(
|
||||
|
||||
return response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_hf_model(model_name):
|
||||
"""
|
||||
初始化Hugging Face模型和tokenizer。
|
||||
|
||||
使用指定的模型名称初始化模型和tokenizer,并根据需要设置padding token。
|
||||
|
||||
参数:
|
||||
- model_name: 模型的名称。
|
||||
|
||||
返回:
|
||||
- hf_model: 初始化的Hugging Face模型。
|
||||
- hf_tokenizer: 初始化的Hugging Face tokenizer。
|
||||
"""
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
@ -232,6 +320,21 @@ def initialize_hf_model(model_name):
|
||||
async def hf_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用缓存的Hugging Face模型进行推理。
|
||||
|
||||
如果缓存中存在相同的输入,则直接返回结果,否则使用指定的模型进行推理并将结果缓存。
|
||||
|
||||
参数:
|
||||
- model: 模型的名称。
|
||||
- prompt: 用户的输入提示。
|
||||
- system_prompt: 系统的提示(可选)。
|
||||
- history_messages: 历史消息列表(可选)。
|
||||
- **kwargs: 其他关键字参数,例如hashing_kv用于缓存存储。
|
||||
|
||||
返回:
|
||||
- response_text: 模型的响应文本。
|
||||
"""
|
||||
model_name = model
|
||||
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
@ -297,32 +400,58 @@ async def hf_model_if_cache(
|
||||
async def ollama_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
异步函数,通过Olama模型生成回答,支持缓存机制以优化性能。
|
||||
|
||||
参数:
|
||||
model: 使用的模型名称。
|
||||
prompt: 用户的提问。
|
||||
system_prompt: 系统的提示,用于设定对话背景。
|
||||
history_messages: 历史对话消息,用于维持对话上下文。
|
||||
**kwargs: 其他参数,包括max_tokens, response_format, host, timeout等。
|
||||
|
||||
返回:
|
||||
生成的模型回答。
|
||||
"""
|
||||
# 移除不需要的参数,以符合Olama客户端的期望
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs.pop("response_format", None)
|
||||
host = kwargs.pop("host", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
|
||||
# 初始化Olama异步客户端
|
||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
|
||||
|
||||
# 构建消息列表,首先添加系统提示(如果有)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 获取哈希存储实例,用于缓存
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
|
||||
# 将历史消息和当前用户提问添加到消息列表
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 如果提供了哈希存储,尝试从缓存中获取回答
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
# 如果缓存中没有回答,调用Olama模型生成回答
|
||||
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
|
||||
|
||||
# 提取生成的回答内容
|
||||
result = response["message"]["content"]
|
||||
|
||||
# 如果使用了哈希存储,将新生成的回答存入缓存
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
|
||||
|
||||
# 返回生成的回答
|
||||
return result
|
||||
|
||||
|
||||
@ -335,8 +464,24 @@ def initialize_lmdeploy_pipeline(
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
):
|
||||
"""
|
||||
初始化lmdeploy管道,用于模型推理,带有缓存机制。
|
||||
|
||||
参数:
|
||||
model: 模型路径。
|
||||
tp: 张量并行度。
|
||||
chat_template: 聊天模板配置。
|
||||
log_level: 日志级别。
|
||||
model_format: 模型格式。
|
||||
quant_policy: 量化策略。
|
||||
|
||||
返回:
|
||||
初始化的lmdeploy管道实例。
|
||||
"""
|
||||
# 导入必要的模块和类
|
||||
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
|
||||
|
||||
# 创建并配置lmdeploy管道
|
||||
lmdeploy_pipe = pipeline(
|
||||
model_path=model,
|
||||
backend_config=TurbomindEngineConfig(
|
||||
@ -347,6 +492,7 @@ def initialize_lmdeploy_pipeline(
|
||||
else None,
|
||||
log_level="WARNING",
|
||||
)
|
||||
# 返回配置好的管道实例
|
||||
return lmdeploy_pipe
|
||||
|
||||
|
||||
@ -361,39 +507,38 @@ async def lmdeploy_model_if_cache(
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
model (str): The path to the model.
|
||||
It could be one of the following options:
|
||||
- i) A local directory path of a turbomind model which is
|
||||
converted by `lmdeploy convert` command or download
|
||||
from ii) and iii).
|
||||
- ii) The model_id of a lmdeploy-quantized model hosted
|
||||
inside a model repo on huggingface.co, such as
|
||||
异步执行语言模型推理,支持缓存。
|
||||
|
||||
该函数初始化 lmdeploy 管道进行模型推理,支持多种模型格式和量化策略。它处理输入的提示文本、系统提示和历史消息,
|
||||
并尝试从缓存中检索响应。如果未命中缓存,则生成响应并缓存结果以供将来使用。
|
||||
|
||||
参数:
|
||||
model (str): 模型路径。
|
||||
可以是以下选项之一:
|
||||
- i) 通过 `lmdeploy convert` 命令转换或从 ii) 和 iii) 下载的本地 turbomind 模型目录路径。
|
||||
- ii) 在 huggingface.co 上托管的 lmdeploy 量化模型的 model_id,例如
|
||||
"InternLM/internlm-chat-20b-4bit",
|
||||
"lmdeploy/llama2-chat-70b-4bit", etc.
|
||||
- iii) The model_id of a model hosted inside a model repo
|
||||
on huggingface.co, such as "internlm/internlm-chat-7b",
|
||||
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
|
||||
and so on.
|
||||
chat_template (str): needed when model is a pytorch model on
|
||||
huggingface.co, such as "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
|
||||
and when the model name of local path did not match the original model name in HF.
|
||||
tp (int): tensor parallel
|
||||
prompt (Union[str, List[str]]): input texts to be completed.
|
||||
do_preprocess (bool): whether pre-process the messages. Default to
|
||||
True, which means chat_template will be applied.
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be True.
|
||||
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
|
||||
Default to be False, which means greedy decoding will be applied.
|
||||
"lmdeploy/llama2-chat-70b-4bit" 等。
|
||||
- iii) 在 huggingface.co 上托管的模型的 model_id,例如
|
||||
"internlm/internlm-chat-7b",
|
||||
"Qwen/Qwen-7B-Chat ",
|
||||
"baichuan-inc/Baichuan2-7B-Chat" 等。
|
||||
chat_template (str): 当模型是 huggingface.co 上的 PyTorch 模型时需要,例如 "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" 等,以及当本地路径的模型名称与 HF 中的原始模型名称不匹配时。
|
||||
tp (int): 张量并行度
|
||||
prompt (Union[str, List[str]]): 要完成的输入文本。
|
||||
do_preprocess (bool): 是否预处理消息。默认为 True,表示将应用 chat_template。
|
||||
skip_special_tokens (bool): 解码时是否移除特殊标记。默认为 True。
|
||||
do_sample (bool): 是否使用采样,否则使用贪心解码。默认为 False,表示使用贪心解码。
|
||||
"""
|
||||
# 导入 lmdeploy 及相关模块,如果未安装则抛出错误
|
||||
try:
|
||||
import lmdeploy
|
||||
from lmdeploy import version_info, GenerationConfig
|
||||
except Exception:
|
||||
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
|
||||
raise ImportError("请在初始化 lmdeploy 后端之前安装 lmdeploy。")
|
||||
|
||||
# 提取并处理关键字参数
|
||||
kwargs.pop("response_format", None)
|
||||
max_new_tokens = kwargs.pop("max_tokens", 512)
|
||||
tp = kwargs.pop("tp", 1)
|
||||
@ -402,16 +547,18 @@ async def lmdeploy_model_if_cache(
|
||||
do_sample = kwargs.pop("do_sample", False)
|
||||
gen_params = kwargs
|
||||
|
||||
# 检查 lmdeploy 版本兼容性,确保支持 do_sample 参数
|
||||
version = version_info
|
||||
if do_sample is not None and version < (0, 6, 0):
|
||||
raise RuntimeError(
|
||||
"`do_sample` parameter is not supported by lmdeploy until "
|
||||
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
|
||||
"`do_sample` 参数在 lmdeploy v0.6.0 之前不受支持,当前使用的 lmdeploy 版本为 {}"
|
||||
.format(lmdeploy.__version__)
|
||||
)
|
||||
else:
|
||||
do_sample = True
|
||||
gen_params.update(do_sample=do_sample)
|
||||
|
||||
# 初始化 lmdeploy 管道
|
||||
lmdeploy_pipe = initialize_lmdeploy_pipeline(
|
||||
model=model,
|
||||
tp=tp,
|
||||
@ -421,25 +568,31 @@ async def lmdeploy_model_if_cache(
|
||||
log_level="WARNING",
|
||||
)
|
||||
|
||||
# 构建消息列表
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 获取哈希存储对象
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 尝试从缓存中获取响应
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
# 配置生成参数
|
||||
gen_config = GenerationConfig(
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**gen_params,
|
||||
)
|
||||
|
||||
# 生成响应
|
||||
response = ""
|
||||
async for res in lmdeploy_pipe.generate(
|
||||
messages,
|
||||
@ -450,14 +603,29 @@ async def lmdeploy_model_if_cache(
|
||||
):
|
||||
response += res.response
|
||||
|
||||
# 缓存生成的响应
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
async def gpt_4o_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用GPT-4o模型完成文本生成任务。
|
||||
|
||||
参数:
|
||||
- prompt: 用户输入的提示文本。
|
||||
- system_prompt: 系统级别的提示文本,用于指导模型生成。
|
||||
- history_messages: 历史对话消息,用于上下文理解。
|
||||
- **kwargs: 其他可变关键字参数。
|
||||
|
||||
返回:
|
||||
- 生成的文本结果。
|
||||
"""
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o",
|
||||
prompt,
|
||||
@ -466,10 +634,21 @@ async def gpt_4o_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def gpt_4o_mini_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用较小的GPT-4o模型完成文本生成任务。
|
||||
|
||||
参数:
|
||||
- prompt: 用户输入的提示文本。
|
||||
- system_prompt: 系统级别的提示文本,用于指导模型生成。
|
||||
- history_messages: 历史对话消息,用于上下文理解。
|
||||
- **kwargs: 其他可变关键字参数。
|
||||
|
||||
返回:
|
||||
- 生成的文本结果。
|
||||
"""
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o-mini",
|
||||
prompt,
|
||||
@ -478,10 +657,21 @@ async def gpt_4o_mini_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用Azure上的OpenAI模型完成文本生成任务。
|
||||
|
||||
参数:
|
||||
- prompt: 用户输入的提示文本。
|
||||
- system_prompt: 系统级别的提示文本,用于指导模型生成。
|
||||
- history_messages: 历史对话消息,用于上下文理解。
|
||||
- **kwargs: 其他可变关键字参数。
|
||||
|
||||
返回:
|
||||
- 生成的文本结果。
|
||||
"""
|
||||
return await azure_openai_complete_if_cache(
|
||||
"conversation-4o-mini",
|
||||
prompt,
|
||||
@ -490,10 +680,21 @@ async def azure_openai_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用Bedrock平台的特定模型完成文本生成任务。
|
||||
|
||||
参数:
|
||||
- prompt: 用户输入的提示文本。
|
||||
- system_prompt: 系统级别的提示文本,用于指导模型生成。
|
||||
- history_messages: 历史对话消息,用于上下文理解。
|
||||
- **kwargs: 其他可变关键字参数。
|
||||
|
||||
返回:
|
||||
- 生成的文本结果。
|
||||
"""
|
||||
return await bedrock_complete_if_cache(
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
prompt,
|
||||
@ -502,10 +703,21 @@ async def bedrock_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def hf_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用Hugging Face模型完成文本生成任务。
|
||||
|
||||
参数:
|
||||
- prompt: 用户输入的提示文本。
|
||||
- system_prompt: 系统级别的提示文本,用于指导模型生成。
|
||||
- history_messages: 历史对话消息,用于上下文理解。
|
||||
- **kwargs: 其他可变关键字参数,包括模型名称。
|
||||
|
||||
返回:
|
||||
- 生成的文本结果。
|
||||
"""
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await hf_model_if_cache(
|
||||
model_name,
|
||||
@ -515,10 +727,21 @@ async def hf_model_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def ollama_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
使用Ollama模型完成文本生成任务。
|
||||
|
||||
参数:
|
||||
- prompt: 用户输入的提示文本。
|
||||
- system_prompt: 系统级别的提示文本,用于指导模型生成。
|
||||
- history_messages: 历史对话消息,用于上下文理解。
|
||||
- **kwargs: 其他可变关键字参数,包括模型名称。
|
||||
|
||||
返回:
|
||||
- 生成的文本结果。
|
||||
"""
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await ollama_model_if_cache(
|
||||
model_name,
|
||||
@ -529,7 +752,9 @@ async def ollama_model_complete(
|
||||
)
|
||||
|
||||
|
||||
# 使用装饰器添加属性,如嵌入维度和最大令牌大小
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
# 使用重试机制处理可能的速率限制、API连接和超时错误
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
@ -541,6 +766,18 @@ async def openai_embedding(
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用OpenAI模型生成文本嵌入。
|
||||
|
||||
参数:
|
||||
- texts: 需要生成嵌入的文本列表
|
||||
- model: 使用的模型名称
|
||||
- base_url: API的基础URL
|
||||
- api_key: API密钥
|
||||
|
||||
返回:
|
||||
- 嵌入的NumPy数组
|
||||
"""
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
@ -552,8 +789,9 @@ async def openai_embedding(
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
# 使用装饰器添加属性,如嵌入维度和最大令牌大小
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
# 使用重试机制处理可能的速率限制、API连接和超时错误
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
@ -565,6 +803,18 @@ async def azure_openai_embedding(
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用Azure OpenAI模型生成文本嵌入。
|
||||
|
||||
参数:
|
||||
- texts: 需要生成嵌入的文本列表
|
||||
- model: 使用的模型名称
|
||||
- base_url: API的基础URL
|
||||
- api_key: API密钥
|
||||
|
||||
返回:
|
||||
- 嵌入的NumPy数组
|
||||
"""
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
@ -581,7 +831,7 @@ async def azure_openai_embedding(
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
# 使用重试机制处理可能的速率限制、API连接和超时错误
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
@ -594,6 +844,19 @@ async def siliconcloud_embedding(
|
||||
max_token_size: int = 512,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用SiliconCloud模型生成文本嵌入。
|
||||
|
||||
参数:
|
||||
- texts: 需要生成嵌入的文本列表
|
||||
- model: 使用的模型名称
|
||||
- base_url: API的基础URL
|
||||
- max_token_size: 最大令牌大小
|
||||
- api_key: API密钥
|
||||
|
||||
返回:
|
||||
- 嵌入的NumPy数组
|
||||
"""
|
||||
if api_key and not api_key.startswith("Bearer "):
|
||||
api_key = "Bearer " + api_key
|
||||
|
||||
@ -633,6 +896,22 @@ async def bedrock_embedding(
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
生成给定文本的嵌入向量。
|
||||
|
||||
使用指定的模型对文本列表进行嵌入处理,支持Amazon Bedrock和Cohere模型。
|
||||
|
||||
参数:
|
||||
- texts: 需要嵌入的文本列表。
|
||||
- model: 使用的模型标识符,默认为"amazon.titan-embed-text-v2:0"。
|
||||
- aws_access_key_id: AWS访问密钥ID。
|
||||
- aws_secret_access_key: AWS秘密访问密钥。
|
||||
- aws_session_token: AWS会话令牌。
|
||||
|
||||
返回:
|
||||
- 嵌入向量的NumPy数组。
|
||||
"""
|
||||
# 设置AWS环境变量
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
@ -643,11 +922,14 @@ async def bedrock_embedding(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
# 创建aioboto3会话
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
# 根据模型提供者进行不同的处理
|
||||
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
# 根据模型版本构建请求体
|
||||
if "v2" in model:
|
||||
body = json.dumps(
|
||||
{
|
||||
@ -661,6 +943,7 @@ async def bedrock_embedding(
|
||||
else:
|
||||
raise ValueError(f"Model {model} is not supported!")
|
||||
|
||||
# 调用Bedrock模型
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
modelId=model,
|
||||
body=body,
|
||||
@ -693,9 +976,22 @@ async def bedrock_embedding(
|
||||
|
||||
|
||||
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||
"""
|
||||
使用Hugging Face模型生成给定文本的嵌入向量。
|
||||
|
||||
参数:
|
||||
- texts: 需要嵌入的文本列表。
|
||||
- tokenizer: Hugging Face的标记器实例。
|
||||
- embed_model: Hugging Face的嵌入模型实例。
|
||||
|
||||
返回:
|
||||
- 嵌入向量的NumPy数组。
|
||||
"""
|
||||
# 对文本进行标记化处理
|
||||
input_ids = tokenizer(
|
||||
texts, return_tensors="pt", padding=True, truncation=True
|
||||
).input_ids
|
||||
# 使用模型生成嵌入向量
|
||||
with torch.no_grad():
|
||||
outputs = embed_model(input_ids)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
@ -703,9 +999,22 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||
|
||||
|
||||
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||
"""
|
||||
使用Ollama模型生成给定文本的嵌入向量。
|
||||
|
||||
参数:
|
||||
- texts: 需要嵌入的文本列表。
|
||||
- embed_model: 使用的嵌入模型标识符。
|
||||
- **kwargs: 传递给Ollama客户端的其他参数。
|
||||
|
||||
返回:
|
||||
- 嵌入向量的列表。
|
||||
"""
|
||||
embed_text = []
|
||||
# 创建Ollama客户端实例
|
||||
ollama_client = ollama.Client(**kwargs)
|
||||
for text in texts:
|
||||
# 调用模型生成嵌入向量
|
||||
data = ollama_client.embeddings(model=embed_model, prompt=text)
|
||||
embed_text.append(data["embedding"])
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -20,27 +20,63 @@ from .base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonKVStorage(BaseKVStorage):
|
||||
"""
|
||||
基于JSON文件的键值存储实现类
|
||||
继承自BaseKVStorage,提供基本的键值存储功能
|
||||
数据以JSON格式保存在文件系统中
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = load_json(self._file_name) or {}
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
"""
|
||||
初始化方法,在对象创建后自动调用
|
||||
- 设置工作目录和文件路径
|
||||
- 加载已存在的JSON数据
|
||||
"""
|
||||
working_dir = self.global_config["working_dir"] # 从全局配置获取工作目录
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") # 构建JSON文件完整路径
|
||||
self._data = load_json(self._file_name) or {} # 加载JSON文件,如果不存在则初始化为空字典
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data") # 记录加载数据的数量
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
"""
|
||||
获取存储中的所有键
|
||||
返回值:
|
||||
list[str]: 包含所有键的列表
|
||||
"""
|
||||
return list(self._data.keys())
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""
|
||||
索引完成后的回调函数
|
||||
将当前内存中的数据写入JSON文件
|
||||
"""
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def get_by_id(self, id):
|
||||
"""
|
||||
通过ID获取单个数据
|
||||
参数:
|
||||
id: 要查询的数据ID
|
||||
返回值:
|
||||
查找到的数据,如果不存在则返回None
|
||||
"""
|
||||
return self._data.get(id, None)
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
"""
|
||||
批量获取多个ID的数据
|
||||
参数:
|
||||
ids: ID列表
|
||||
fields: 可选,指定要返回的字段列表
|
||||
返回值:
|
||||
list: 包含查询结果的列表,每个元素对应一个ID的数据
|
||||
"""
|
||||
if fields is None:
|
||||
# 如果未指定字段,返回完整数据
|
||||
return [self._data.get(id, None) for id in ids]
|
||||
# 如果指定了字段,只返回指定的字段
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items() if k in fields}
|
||||
@ -51,38 +87,80 @@ class JsonKVStorage(BaseKVStorage):
|
||||
]
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
"""
|
||||
过滤出不存在于存储中的键
|
||||
参数:
|
||||
data: 要检查的键列表
|
||||
返回值:
|
||||
set[str]: 不存在的键集合
|
||||
"""
|
||||
return set([s for s in data if s not in self._data])
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
return left_data
|
||||
"""
|
||||
更新或插入数据
|
||||
参数:
|
||||
data: 要更新/插入的数据字典,格式为 {id: {字段: 值}}
|
||||
返回值:
|
||||
dict: 实际插入的新数据(不包含更新的数据)
|
||||
"""
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data} # 筛选出新数据
|
||||
self._data.update(left_data) # 更新存储
|
||||
return left_data # 返回新插入的数据
|
||||
|
||||
async def drop(self):
|
||||
"""
|
||||
清空所有数据
|
||||
将内存中的数据字典重置为空
|
||||
"""
|
||||
self._data = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
"""
|
||||
向量数据库存储实现类
|
||||
基于NanoVectorDB实现向量存储和检索功能
|
||||
支持向量的增删改查操作
|
||||
"""
|
||||
|
||||
# 余弦相似度阈值,用于过滤搜索结果
|
||||
cosine_better_than_threshold: float = 0.2
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
初始化方法,在对象创建后自动调用
|
||||
设置存储文件路径、批处理大小,并初始化向量数据库客户端
|
||||
"""
|
||||
# 构建向量数据库存储文件路径
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
# 设置批处理大小
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# 初始化向量数据库客户端
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
)
|
||||
# 从配置中获取相似度阈值
|
||||
self.cosine_better_than_threshold = self.global_config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""
|
||||
更新或插入向量数据
|
||||
参数:
|
||||
data: 包含向量数据的字典,格式为 {id: {字段: 值}}
|
||||
返回值:
|
||||
list: 插入结果
|
||||
"""
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
|
||||
# 准备数据,提取元数据字段
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
@ -90,28 +168,49 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
|
||||
# 提取内容并分批处理
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
# 并行计算向量嵌入
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
|
||||
# 将向量添加到数据中
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
|
||||
# 执行更新/插入操作
|
||||
results = self._client.upsert(datas=list_data)
|
||||
return results
|
||||
|
||||
async def query(self, query: str, top_k=5):
|
||||
"""
|
||||
查询最相似的向量
|
||||
参数:
|
||||
query: 查询文本
|
||||
top_k: 返回的最相似结果数量
|
||||
返回值:
|
||||
list: 包含相似度结果的列表
|
||||
"""
|
||||
# 计算查询文本的向量表示
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
|
||||
# 执行向量检索
|
||||
results = self._client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
|
||||
# 格式化返回结果
|
||||
results = [
|
||||
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
|
||||
]
|
||||
@ -119,12 +218,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
@property
|
||||
def client_storage(self):
|
||||
"""获取底层存储对象"""
|
||||
return getattr(self._client, "_NanoVectorDB__storage")
|
||||
|
||||
async def delete_entity(self, entity_name: str):
|
||||
"""
|
||||
删除指定实体
|
||||
参数:
|
||||
entity_name: 要删除的实体名称
|
||||
"""
|
||||
try:
|
||||
# 计算实体ID
|
||||
entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
|
||||
|
||||
# 检查并删除实体
|
||||
if self._client.get(entity_id):
|
||||
self._client.delete(entity_id)
|
||||
logger.info(f"Entity {entity_name} have been deleted.")
|
||||
@ -134,7 +241,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error while deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_relation(self, entity_name: str):
|
||||
"""
|
||||
删除与指定实体相关的所有关系
|
||||
参数:
|
||||
entity_name: 实体名称
|
||||
"""
|
||||
try:
|
||||
# 查找所有相关关系
|
||||
relations = [
|
||||
dp
|
||||
for dp in self.client_storage["data"]
|
||||
@ -142,6 +255,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
]
|
||||
ids_to_delete = [relation["__id__"] for relation in relations]
|
||||
|
||||
# 执行删除操作
|
||||
if ids_to_delete:
|
||||
self._client.delete(ids_to_delete)
|
||||
logger.info(
|
||||
@ -155,19 +269,38 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""索引完成后的回调函数,保存数据到存储文件"""
|
||||
self._client.save()
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
"""
|
||||
基于NetworkX的图存储实现类
|
||||
提供图数据的存储、读取和操作功能
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name) -> nx.Graph:
|
||||
"""
|
||||
从文件加载图数据
|
||||
参数:
|
||||
file_name: GraphML文件路径
|
||||
返回值:
|
||||
nx.Graph: 加载的图对象,如果文件不存在返回None
|
||||
"""
|
||||
if os.path.exists(file_name):
|
||||
return nx.read_graphml(file_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def write_nx_graph(graph: nx.Graph, file_name):
|
||||
"""
|
||||
将图数据写入文件
|
||||
参数:
|
||||
graph: 要保存的图对象
|
||||
file_name: 保存路径
|
||||
"""
|
||||
logger.info(
|
||||
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
|
||||
)
|
||||
@ -175,40 +308,51 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
@staticmethod
|
||||
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
|
||||
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
|
||||
"""
|
||||
获取图的最大连通分量,并确保节点和边的顺序稳定
|
||||
参考: https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
参数:
|
||||
graph: 输入图
|
||||
返回值:
|
||||
nx.Graph: 处理后的稳定图
|
||||
"""
|
||||
from graspologic.utils import largest_connected_component
|
||||
|
||||
graph = graph.copy()
|
||||
graph = cast(nx.Graph, largest_connected_component(graph))
|
||||
# 对节点标签进行标准化处理
|
||||
node_mapping = {
|
||||
node: html.unescape(node.upper().strip()) for node in graph.nodes()
|
||||
} # type: ignore
|
||||
}
|
||||
graph = nx.relabel_nodes(graph, node_mapping)
|
||||
return NetworkXStorage._stabilize_graph(graph)
|
||||
|
||||
@staticmethod
|
||||
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
||||
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
Ensure an undirected graph with the same relationships will always be read the same way.
|
||||
"""
|
||||
确保无向图的关系始终以相同的方式读取
|
||||
参数:
|
||||
graph: 输入图
|
||||
返回值:
|
||||
nx.Graph: 稳定化后的图
|
||||
"""
|
||||
# 根据图的类型创建新图
|
||||
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
|
||||
|
||||
# 对节点进行排序
|
||||
sorted_nodes = graph.nodes(data=True)
|
||||
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
|
||||
|
||||
# 添加排序后的节点
|
||||
fixed_graph.add_nodes_from(sorted_nodes)
|
||||
edges = list(graph.edges(data=True))
|
||||
|
||||
# 对于无向图,确保边的源节点和目标节点有固定顺序
|
||||
if not graph.is_directed():
|
||||
|
||||
def _sort_source_target(edge):
|
||||
source, target, edge_data = edge
|
||||
if source > target:
|
||||
temp = source
|
||||
source = target
|
||||
target = temp
|
||||
source, target = target, source
|
||||
return source, target, edge_data
|
||||
|
||||
edges = [_sort_source_target(edge) for edge in edges]
|
||||
@ -216,12 +360,18 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
def _get_edge_key(source: Any, target: Any) -> str:
|
||||
return f"{source} -> {target}"
|
||||
|
||||
# 对边进行排序
|
||||
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
|
||||
|
||||
fixed_graph.add_edges_from(edges)
|
||||
return fixed_graph
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
初始化方法
|
||||
- 设置图存储文件路径
|
||||
- 加载已存在的图数据
|
||||
- 初始化节点嵌入算法
|
||||
"""
|
||||
self._graphml_xml_file = os.path.join(
|
||||
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
||||
)
|
||||
@ -236,46 +386,56 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
}
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""索引完成后的回调,保存图数据到文件"""
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""检查节点是否存在"""
|
||||
return self._graph.has_node(node_id)
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""检查边是否存在"""
|
||||
return self._graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""获取节点数据"""
|
||||
return self._graph.nodes.get(node_id)
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""获取节点的度"""
|
||||
return self._graph.degree(node_id)
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""获取边的度(源节点度 + 目标节点度)"""
|
||||
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""获取边的数据"""
|
||||
return self._graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
"""获取节点的所有边"""
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
"""更新或插入节点"""
|
||||
self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
"""更新或插入边"""
|
||||
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
"""
|
||||
Delete a node from the graph based on the specified node_id.
|
||||
|
||||
:param node_id: The node_id to delete
|
||||
删除指定的节点
|
||||
参数:
|
||||
node_id: 要删除的节点ID
|
||||
"""
|
||||
if self._graph.has_node(node_id):
|
||||
self._graph.remove_node(node_id)
|
||||
@ -284,12 +444,23 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
"""
|
||||
使用指定算法进行节点嵌入
|
||||
参数:
|
||||
algorithm: 嵌入算法名称
|
||||
返回值:
|
||||
tuple: (嵌入向量数组, 节点ID列表)
|
||||
"""
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
# @TODO: NOT USED
|
||||
async def _node2vec_embed(self):
|
||||
"""
|
||||
使用node2vec算法进行节点嵌入(未使用)
|
||||
返回值:
|
||||
tuple: (嵌入向量数组, 节点ID列表)
|
||||
"""
|
||||
from graspologic import embed
|
||||
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
|
Loading…
Reference in New Issue
Block a user