lightrag.py was annotated with Tongyi Lingma (通义灵码); the other three files were annotated with Cursor.

many2many 2024-11-16 11:29:02 +08:00
parent c0fa4da53d
commit c8ee7286cb
4 changed files with 1261 additions and 260 deletions

View File

@@ -1,3 +1,8 @@
"""
LightRAG - a lightweight retrieval-augmented generation system.

This module implements a graph-based document retrieval and question-answering
system, supporting document storage, retrieval, and knowledge-graph construction.
"""
import asyncio
import os
from dataclasses import asdict, dataclass, field
@@ -5,144 +10,170 @@ from datetime import datetime
from functools import partial
from typing import Type, cast

# Import LLM-related functionality
from .llm import (
    gpt_4o_mini_complete,  # GPT model completion
    openai_embedding,  # OpenAI text embeddings
)

# Import core operations
from .operate import (
    chunking_by_token_size,  # text chunking
    extract_entities,  # entity extraction
    local_query,  # local query
    global_query,  # global query
    hybrid_query,  # hybrid query
    naive_query,  # naive query
)

# Import storage implementations
from .storage import (
    JsonKVStorage,  # JSON key-value storage
    NanoVectorDBStorage,  # vector database storage
    NetworkXStorage,  # graph storage
)
from .kg.neo4j_impl import Neo4JStorage  # Neo4j graph database implementation

# Possible future KG integrations
# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )

# Import utility functions
from .utils import (
    EmbeddingFunc,  # embedding function type
    compute_mdhash_id,  # compute MD5-hash IDs
    limit_async_func_call,  # limit concurrent async calls
    convert_response_to_json,  # convert responses to JSON
    logger,  # logger
    set_logger,  # logger setup
)

# Import base classes
from .base import (
    BaseGraphStorage,  # graph storage base class
    BaseKVStorage,  # key-value storage base class
    BaseVectorStorage,  # vector storage base class
    StorageNameSpace,  # storage namespace
    QueryParam,  # query parameters
)


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Get or create an event loop.
    If the current thread has no event loop, create a new one.

    Returns:
        asyncio.AbstractEventLoop: the event loop instance
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


@dataclass
class LightRAG:
    """
    Main class of the lightweight retrieval-augmented generation (LightRAG) system.
    Implements document storage, retrieval, knowledge-graph construction, and question answering.
    """

    # Working directory, used to store all cache files
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )

    # Knowledge-graph storage type; defaults to the NetworkX implementation
    kg: str = field(default="NetworkXStorage")

    # Logging level
    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # Text chunking parameters
    chunk_token_size: int = 1200  # target token count per chunk
    chunk_overlap_token_size: int = 100  # token overlap between adjacent chunks
    tiktoken_model_name: str = "gpt-4o-mini"  # model name used for token counting

    # Entity extraction parameters
    entity_extract_max_gleaning: int = 1  # maximum number of extraction passes
    entity_summary_to_max_tokens: int = 500  # maximum tokens for entity summaries

    # Node embedding configuration
    node_embedding_algorithm: str = "node2vec"  # node embedding algorithm
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,  # embedding dimensionality
            "num_walks": 10,  # random walks per node
            "walk_length": 40,  # length of each random walk
            "window_size": 2,  # context window size
            "iterations": 3,  # training iterations
            "random_seed": 3,  # random seed
        }
    )

    # Text embedding configuration
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)  # defaults to OpenAI embeddings
    embedding_batch_num: int = 32  # batch size
    embedding_func_max_async: int = 16  # maximum concurrent requests

    # Language model configuration
    llm_model_func: callable = gpt_4o_mini_complete  # default language model
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # model name
    llm_model_max_token_size: int = 32768  # maximum token limit of the model
    llm_model_max_async: int = 16  # maximum concurrent requests
    llm_model_kwargs: dict = field(default_factory=dict)  # extra model parameters

    # Storage configuration
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage  # key-value storage class
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage  # vector storage class
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)  # extra vector storage parameters
    enable_llm_cache: bool = True  # whether to cache LLM responses

    # Extension configuration
    addon_params: dict = field(default_factory=dict)  # additional parameters
    convert_response_to_json_func: callable = convert_response_to_json  # JSON conversion function

    def __post_init__(self):
        """
        Initialization hook called automatically after object creation.
        Sets up logging, initializes the storage backends, and configures the functional components.
        """
        # Configure logging
        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        # Log the initialization parameters
        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # Select the graph storage implementation class according to the configuration
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.kg
        ]

        # Make sure the working directory exists
        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        # Initialize full-document storage
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )

        # Initialize text-chunk storage
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )

        # Initialize the LLM response cache (if enabled)
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
@@ -150,32 +181,40 @@ class LightRAG:
            if self.enable_llm_cache
            else None
        )

        # Initialize the entity-relation graph storage
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )

        # Apply the concurrency limit to the embedding function
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        # Initialize the vector database storages
        # Stores vector representations of entities
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        # Stores vector representations of relationships
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        # Stores vector representations of text chunks
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        # Apply the concurrency limit and caching to the language model function
        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
@@ -185,33 +224,62 @@ class LightRAG:
        )

    def _get_storage_class(self) -> Type[BaseGraphStorage]:
        """
        Return the graph storage implementations.
        The graph storage backend (Neo4J or NetworkX) is selected according to the configuration.

        Returns:
            Type[BaseGraphStorage]: the graph storage class
        """
        return {
            "Neo4JStorage": Neo4JStorage,
            "NetworkXStorage": NetworkXStorage,
        }

    def insert(self, string_or_strings):
        """
        Insert documents synchronously.
        Inserts a string or list of strings into the system for processing.

        Args:
            string_or_strings: a single string or a list of strings with the document content
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """
        Insert documents asynchronously.
        Processes the document content, including chunking, entity extraction, and vector storage.

        Args:
            string_or_strings: a single string or a list of strings with the document content
        """
        try:
            # Make sure the input is a list
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            # Generate a unique ID for each document and build the document dict
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }

            # Filter out documents that already exist
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # Chunk the documents
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                # Chunk each document and generate a unique ID for every chunk
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
@@ -225,19 +293,25 @@ class LightRAG:
                    )
                }
                inserting_chunks.update(chunks)

            # Filter out chunks that already exist
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            # Insert the new chunks into the vector database
            await self.chunks_vdb.upsert(inserting_chunks)

            # Extract entities and relationships and update the knowledge graph
            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
@@ -246,38 +320,70 @@ class LightRAG:
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            # Update the document and chunk storages
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            # Run cleanup after the insert completes
            await self._insert_done()

    async def _insert_done(self):
        """
        Callback invoked after an insert completes.
        Updates the indexes of all storage instances and persists their state.
        """
        tasks = []
        # Iterate over all storage instances that need the callback
        for storage_inst in [
            self.full_docs,  # full-document storage
            self.text_chunks,  # text-chunk storage
            self.llm_response_cache,  # LLM response cache
            self.entities_vdb,  # entity vector database
            self.relationships_vdb,  # relationship vector database
            self.chunks_vdb,  # chunk vector database
            self.chunk_entity_relation_graph,  # entity-relation graph
        ]:
            if storage_inst is None:
                continue
            # Add each storage instance's callback to the task list
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        # Run all callbacks concurrently
        await asyncio.gather(*tasks)

    def query(self, query: str, param: QueryParam = QueryParam()):
        """
        Run a query synchronously.

        Args:
            query: the query text
            param: query parameter configuration
        Returns:
            the query result
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        """
        Run a query asynchronously.
        Supports several query modes: local, global, hybrid, and naive.

        Args:
            query: the query text
            param: query parameter configuration
        Returns:
            the query result
        """
        # Dispatch to the query method matching the requested mode
        if param.mode == "local":
            # Local query: relies mainly on local context
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
@@ -288,6 +394,7 @@ class LightRAG:
                asdict(self),
            )
        elif param.mode == "global":
            # Global query: considers the whole knowledge graph
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
@@ -298,6 +405,7 @@ class LightRAG:
                asdict(self),
            )
        elif param.mode == "hybrid":
            # Hybrid query: combines local and global information
            response = await hybrid_query(
                query,
                self.chunk_entity_relation_graph,
@@ -308,6 +416,7 @@ class LightRAG:
                asdict(self),
            )
        elif param.mode == "naive":
            # Naive query: based directly on text similarity
            response = await naive_query(
                query,
                self.chunks_vdb,
@@ -317,38 +426,73 @@ class LightRAG:
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        # Run cleanup after the query completes
        await self._query_done()
        return response

    async def _query_done(self):
        """
        Callback invoked after a query completes.
        Mainly used to update the state of the LLM response cache.
        """
        tasks = []
        # Currently only the LLM response cache needs the callback
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        # Run all callbacks concurrently
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        """
        Delete the given entity synchronously.

        Args:
            entity_name: name of the entity to delete
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        """
        Delete the given entity and all related information asynchronously.

        Args:
            entity_name: name of the entity to delete
        """
        # Normalize the entity name (uppercase and wrap in quotes)
        entity_name = f'"{entity_name.upper()}"'
        try:
            # Delete the entity's data from each storage in turn:
            # 1. remove it from the entity vector database
            await self.entities_vdb.delete_entity(entity_name)
            # 2. remove related relationships from the relationship vector database
            await self.relationships_vdb.delete_relation(entity_name)
            # 3. remove the node from the knowledge graph
            await self.chunk_entity_relation_graph.delete_node(entity_name)
            # Log successful deletion
            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            # Run cleanup after the deletion completes
            await self._delete_by_entity_done()
        except Exception as e:
            # Log any error raised during deletion
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        """
        Callback invoked after an entity deletion completes.
        Updates the state of all related storage instances.
        """
        tasks = []
        # Storage instances that need updating:
        # - entity vector database
        # - relationship vector database
        # - entity-relation graph
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
@@ -356,5 +500,11 @@ class LightRAG:
        ]:
            if storage_inst is None:
                continue
            # Add each storage instance's callback to the task list
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        # Run all callbacks concurrently
        await asyncio.gather(*tasks)
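For context, a minimal usage sketch of the class annotated above, not part of this commit; it assumes the package is importable as lightrag, that an OpenAI API key is available in the environment, and that the working directory name is arbitrary.

# --- illustrative usage sketch, not part of this commit ---
import os
from lightrag import LightRAG, QueryParam

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder credential, assumed to be set by the user

rag = LightRAG(working_dir="./lightrag_cache_demo")  # hypothetical working directory

# Insert one or more documents; chunking, entity extraction, and embedding run internally.
rag.insert("LightRAG is a lightweight retrieval-augmented generation system.")

# Query in any of the four modes dispatched by aquery().
print(rag.query("What is LightRAG?", param=QueryParam(mode="naive")))
print(rag.query("What is LightRAG?", param=QueryParam(mode="hybrid")))

# Remove an entity and its relationships from all storages.
rag.delete_by_entity("LightRAG")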

View File

@@ -31,9 +31,10 @@ from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs

# Disable tokenizers parallelism to avoid problems caused by parallel tokenization
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Retry decorator handling OpenAI API rate-limit, connection, and timeout errors
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -48,35 +49,64 @@ async def openai_complete_if_cache(
    api_key=None,
    **kwargs,
) -> str:
    """
    Asynchronously obtain a completion from the OpenAI API, with caching support.

    Args:
        model: name of the model to use
        prompt: the user prompt
        system_prompt: optional system prompt
        history_messages: optional message history
        base_url: optional API base URL
        api_key: optional API key
        **kwargs: additional parameters
    Returns:
        str: the text generated by the model
    """
    # Put the API key into the environment
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    # Initialize the asynchronous OpenAI client
    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )

    # Initialize the hashing KV store and the message list
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []

    # Add the system prompt to the message list
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Append the history messages and the current prompt
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Check whether a cached result exists
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Call the OpenAI API to get the completion
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    # Cache the result
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    # Return the generated text
    return response.choices[0].message.content
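A minimal sketch of how the caching path above might be exercised; the cache directory and config shape are assumptions, and an OpenAI API key is expected in the environment.

# --- illustrative sketch, not part of this commit ---
import asyncio
from lightrag.llm import openai_complete_if_cache
from lightrag.storage import JsonKVStorage

# Assumed config shape: JsonKVStorage only needs "working_dir" from the global config.
cache = JsonKVStorage(namespace="llm_response_cache",
                      global_config={"working_dir": "./demo_cache"})

async def ask():
    # The first call hits the API; the second returns the cached answer for identical arguments.
    for _ in range(2):
        text = await openai_complete_if_cache(
            "gpt-4o-mini", "Say hello", system_prompt="Be brief", hashing_kv=cache
        )
        print(text)

asyncio.run(ask())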
# Similar to openai_complete_if_cache, but for the Azure OpenAI service
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -91,45 +121,71 @@ async def azure_openai_complete_if_cache(
    api_key=None,
    **kwargs,
):
    """
    Asynchronously obtain a completion from the Azure OpenAI API, with caching support.

    Args:
        model: name of the model to use
        prompt: the user prompt
        system_prompt: optional system prompt
        history_messages: optional message history
        base_url: optional API base URL
        api_key: optional API key
        **kwargs: additional parameters
    Returns:
        str: the text generated by the model
    """
    # Put the API key and endpoint into the environment
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url

    # Initialize the asynchronous Azure OpenAI client
    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    # Initialize the hashing KV store and the message list
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []

    # Add the system prompt to the message list
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Append the history messages and the current prompt
    messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    # Check whether a cached result exists
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Call the Azure OpenAI API to get the completion
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    # Cache the result
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    # Return the generated text
    return response.choices[0].message.content


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock."""


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
@@ -145,6 +201,25 @@ async def bedrock_complete_if_cache(
    aws_session_token=None,
    **kwargs,
) -> str:
    """
    Asynchronously complete text with Amazon Bedrock, with caching support.
    On a cache hit the cached result is returned directly; the call is retried on failure.

    Args:
        model: model ID of the Bedrock model to use
        prompt: the user prompt
        system_prompt: optional system prompt
        history_messages: conversation history used for dialogue context
        aws_access_key_id: AWS access key ID
        aws_secret_access_key: AWS secret access key
        aws_session_token: AWS session token
        **kwargs: additional parameters, e.g. inference parameters
    Returns:
        str: the generated text
    """
    # Set up the AWS credentials
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
@@ -155,24 +230,24 @@ async def bedrock_complete_if_cache(
        "AWS_SESSION_TOKEN", aws_session_token
    )

    # Fix the message history format
    messages = []
    for history_message in history_messages:
        message = copy.copy(history_message)
        message["content"] = [{"text": message["content"]}]
        messages.append(message)

    # Add the user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})

    # Initialize the Converse API arguments
    args = {"modelId": model, "messages": messages}

    # Define the system prompt
    if system_prompt:
        args["system"] = [{"text": system_prompt}]

    # Map and set up inference parameters
    inference_params_map = {
        "max_tokens": "maxTokens",
        "top_p": "topP",
@@ -187,6 +262,7 @@ async def bedrock_complete_if_cache(
                kwargs.pop(param)
            )

    # Handle the cache lookup
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
@@ -194,7 +270,7 @@ async def bedrock_complete_if_cache(
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Call the model via the Converse API
    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        try:
@@ -202,6 +278,7 @@ async def bedrock_complete_if_cache(
        except Exception as e:
            raise BedrockError(e)

        # Update the cache (if enabled)
        if hashing_kv is not None:
            await hashing_kv.upsert(
                {
@@ -214,9 +291,20 @@ async def bedrock_complete_if_cache(
        return response["output"]["message"]["content"][0]["text"]


@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
    """
    Initialize a Hugging Face model and tokenizer.
    Loads the model and tokenizer for the given model name and sets a padding token if needed.

    Args:
        model_name: name of the model
    Returns:
        hf_model: the initialized Hugging Face model
        hf_tokenizer: the initialized Hugging Face tokenizer
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
@@ -232,6 +320,21 @@ def initialize_hf_model(model_name):
async def hf_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Run inference with a cached Hugging Face model.
    Returns the cached result for identical inputs; otherwise runs the model and caches the result.

    Args:
        model: name of the model
        prompt: the user prompt
        system_prompt: optional system prompt
        history_messages: optional message history
        **kwargs: additional keyword arguments, e.g. hashing_kv for the cache store
    Returns:
        response_text: the model's response text
    """
    model_name = model
    hf_model, hf_tokenizer = initialize_hf_model(model_name)
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
@@ -297,32 +400,58 @@ async def hf_model_if_cache(
async def ollama_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Asynchronously generate an answer with an Ollama model, with caching to improve performance.

    Args:
        model: name of the model to use
        prompt: the user question
        system_prompt: system prompt used to set the conversation context
        history_messages: conversation history used to maintain context
        **kwargs: additional parameters, including max_tokens, response_format, host, timeout
    Returns:
        the generated answer
    """
    # Drop parameters the Ollama client does not expect
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)

    # Initialize the asynchronous Ollama client
    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)

    # Build the message list, starting with the system prompt (if any)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the hashing KV store used for caching
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)

    # Append the history messages and the current user question
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # If a hashing store was provided, try to fetch the answer from the cache
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # On a cache miss, call the Ollama model to generate the answer
    response = await ollama_client.chat(model=model, messages=messages, **kwargs)

    # Extract the generated answer
    result = response["message"]["content"]

    # If a hashing store is in use, cache the newly generated answer
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})

    # Return the generated answer
    return result
@@ -335,8 +464,24 @@ def initialize_lmdeploy_pipeline(
    model_format="hf",
    quant_policy=0,
):
    """
    Initialize an lmdeploy pipeline for model inference, with caching.

    Args:
        model: path to the model
        tp: tensor parallelism degree
        chat_template: chat template configuration
        log_level: logging level
        model_format: model format
        quant_policy: quantization policy
    Returns:
        the initialized lmdeploy pipeline instance
    """
    # Import the required modules and classes
    from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig

    # Create and configure the lmdeploy pipeline
    lmdeploy_pipe = pipeline(
        model_path=model,
        backend_config=TurbomindEngineConfig(
@@ -347,6 +492,7 @@ def initialize_lmdeploy_pipeline(
        else None,
        log_level="WARNING",
    )
    # Return the configured pipeline instance
    return lmdeploy_pipe
@@ -361,39 +507,38 @@ async def lmdeploy_model_if_cache(
    **kwargs,
) -> str:
    """
    Asynchronously run language-model inference with lmdeploy, with caching support.

    The function initializes an lmdeploy pipeline for inference, supporting several model
    formats and quantization policies. It assembles the prompt, system prompt, and message
    history, tries to retrieve a response from the cache, and on a cache miss generates a
    response and caches it for future use.

    Args:
        model (str): path to the model. It can be one of the following options:
            - i) a local directory path of a turbomind model converted by the
              `lmdeploy convert` command or downloaded from ii) and iii);
            - ii) the model_id of an lmdeploy-quantized model hosted on huggingface.co,
              such as "InternLM/internlm-chat-20b-4bit", "lmdeploy/llama2-chat-70b-4bit";
            - iii) the model_id of a model hosted on huggingface.co, such as
              "internlm/internlm-chat-7b", "Qwen/Qwen-7B-Chat ",
              "baichuan-inc/Baichuan2-7B-Chat".
        chat_template (str): required when the model is a PyTorch model on huggingface.co,
            such as "internlm-chat-7b", "Qwen-7B-Chat ", "Baichuan2-7B-Chat", and when the
            model name of a local path does not match the original model name on HF.
        tp (int): tensor parallelism degree
        prompt (Union[str, List[str]]): input texts to be completed
        do_preprocess (bool): whether to pre-process the messages. Defaults to True, which
            means the chat_template will be applied.
        skip_special_tokens (bool): whether to remove special tokens during decoding.
            Defaults to True.
        do_sample (bool): whether to use sampling; otherwise greedy decoding is used.
            Defaults to False, which means greedy decoding is applied.
    """
    # Import lmdeploy and related modules; raise an error if it is not installed
    try:
        import lmdeploy
        from lmdeploy import version_info, GenerationConfig
    except Exception:
        raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")

    # Extract and process keyword arguments
    kwargs.pop("response_format", None)
    max_new_tokens = kwargs.pop("max_tokens", 512)
    tp = kwargs.pop("tp", 1)
@@ -402,16 +547,18 @@ async def lmdeploy_model_if_cache(
    do_sample = kwargs.pop("do_sample", False)
    gen_params = kwargs

    # Check lmdeploy version compatibility to make sure the do_sample parameter is supported
    version = version_info
    if do_sample is not None and version < (0, 6, 0):
        raise RuntimeError(
            "`do_sample` parameter is not supported by lmdeploy until "
            f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
        )
    else:
        do_sample = True
        gen_params.update(do_sample=do_sample)

    # Initialize the lmdeploy pipeline
    lmdeploy_pipe = initialize_lmdeploy_pipeline(
        model=model,
        tp=tp,
@@ -421,25 +568,31 @@ async def lmdeploy_model_if_cache(
        log_level="WARNING",
    )

    # Build the message list
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the hashing KV store
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Try to fetch a response from the cache
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Configure the generation parameters
    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
        max_new_tokens=max_new_tokens,
        **gen_params,
    )

    # Generate the response
    response = ""
    async for res in lmdeploy_pipe.generate(
        messages,
@@ -450,14 +603,29 @@ async def lmdeploy_model_if_cache(
    ):
        response += res.response

    # Cache the generated response
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
    return response
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Complete text with the GPT-4o model.

    Args:
        prompt: the user prompt
        system_prompt: system-level prompt used to steer generation
        history_messages: conversation history used for context
        **kwargs: additional keyword arguments
    Returns:
        the generated text
    """
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
@@ -466,10 +634,21 @@ async def gpt_4o_complete(
        **kwargs,
    )


async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Complete text with the smaller GPT-4o-mini model.

    Args:
        prompt: the user prompt
        system_prompt: system-level prompt used to steer generation
        history_messages: conversation history used for context
        **kwargs: additional keyword arguments
    Returns:
        the generated text
    """
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
@@ -478,10 +657,21 @@ async def gpt_4o_mini_complete(
        **kwargs,
    )
async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Complete text with an OpenAI model hosted on Azure.

    Args:
        prompt: the user prompt
        system_prompt: system-level prompt used to steer generation
        history_messages: conversation history used for context
        **kwargs: additional keyword arguments
    Returns:
        the generated text
    """
    return await azure_openai_complete_if_cache(
        "conversation-4o-mini",
        prompt,
@@ -490,10 +680,21 @@ async def azure_openai_complete(
        **kwargs,
    )


async def bedrock_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Complete text with a specific model on the Bedrock platform.

    Args:
        prompt: the user prompt
        system_prompt: system-level prompt used to steer generation
        history_messages: conversation history used for context
        **kwargs: additional keyword arguments
    Returns:
        the generated text
    """
    return await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
@@ -502,10 +703,21 @@ async def bedrock_complete(
        **kwargs,
    )
async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Complete text with a Hugging Face model.

    Args:
        prompt: the user prompt
        system_prompt: system-level prompt used to steer generation
        history_messages: conversation history used for context
        **kwargs: additional keyword arguments, including the model name
    Returns:
        the generated text
    """
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await hf_model_if_cache(
        model_name,
@@ -515,10 +727,21 @@ async def hf_model_complete(
        **kwargs,
    )


async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    Complete text with an Ollama model.

    Args:
        prompt: the user prompt
        system_prompt: system-level prompt used to steer generation
        history_messages: conversation history used for context
        **kwargs: additional keyword arguments, including the model name
    Returns:
        the generated text
    """
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await ollama_model_if_cache(
        model_name,
@@ -529,7 +752,9 @@ async def ollama_model_complete(
    )
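Since the *_complete wrappers read llm_model_name from the cache's global_config, switching backends is a configuration change on the LightRAG side. A hedged configuration sketch follows; the local Ollama model tags, the embedding dimension, and the EmbeddingFunc field names are assumptions, and a running Ollama server is required.

# --- illustrative configuration sketch, not part of this commit ---
from lightrag import LightRAG
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./lightrag_cache_ollama",      # hypothetical directory
    llm_model_func=ollama_model_complete,       # dispatches through ollama_model_if_cache
    llm_model_name="qwen2.5:7b",                # assumed local Ollama model tag
    embedding_func=EmbeddingFunc(
        embedding_dim=768,                      # assumed dimension of the chosen embedding model
        max_token_size=8192,
        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
    ),
)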
# The decorator attaches attributes such as embedding dimension and maximum token size
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
# Retry on possible rate-limit, API connection, and timeout errors
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -541,6 +766,18 @@ async def openai_embedding(
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    """
    Generate text embeddings with an OpenAI model.

    Args:
        texts: list of texts to embed
        model: name of the model to use
        base_url: API base URL
        api_key: API key
    Returns:
        a NumPy array of embeddings
    """
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
@@ -552,8 +789,9 @@ async def openai_embedding(
    )
    return np.array([dp.embedding for dp in response.data])
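A sketch of how a custom embedding function could be given the same attributes via the decorator used above; the sentence-transformers dependency and model name are assumptions, not part of this commit.

# --- illustrative sketch, not part of this commit ---
import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs

@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=5000)
async def local_embedding(texts: list[str]) -> np.ndarray:
    # Placeholder backend: any model mapping texts to fixed-size vectors would do here.
    from sentence_transformers import SentenceTransformer  # assumed dependency
    model = SentenceTransformer("all-MiniLM-L6-v2")        # assumed model name
    return np.array(model.encode(texts))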
# The decorator attaches attributes such as embedding dimension and maximum token size
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
# Retry on possible rate-limit, API connection, and timeout errors
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -565,6 +803,18 @@ async def azure_openai_embedding(
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    """
    Generate text embeddings with an Azure OpenAI model.

    Args:
        texts: list of texts to embed
        model: name of the model to use
        base_url: API base URL
        api_key: API key
    Returns:
        a NumPy array of embeddings
    """
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
@@ -581,7 +831,7 @@ async def azure_openai_embedding(
    )
    return np.array([dp.embedding for dp in response.data])


# Retry on possible rate-limit, API connection, and timeout errors
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -594,6 +844,19 @@ async def siliconcloud_embedding(
    max_token_size: int = 512,
    api_key: str = None,
) -> np.ndarray:
    """
    Generate text embeddings with a SiliconCloud model.

    Args:
        texts: list of texts to embed
        model: name of the model to use
        base_url: API base URL
        max_token_size: maximum token size
        api_key: API key
    Returns:
        a NumPy array of embeddings
    """
    if api_key and not api_key.startswith("Bearer "):
        api_key = "Bearer " + api_key
@@ -633,6 +896,22 @@ async def bedrock_embedding(
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    """
    Generate embedding vectors for the given texts.
    Embeds a list of texts with the specified model; supports Amazon Bedrock and Cohere models.

    Args:
        texts: list of texts to embed
        model: model identifier to use, defaults to "amazon.titan-embed-text-v2:0"
        aws_access_key_id: AWS access key ID
        aws_secret_access_key: AWS secret access key
        aws_session_token: AWS session token
    Returns:
        a NumPy array of embedding vectors
    """
    # Set the AWS environment variables
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
@@ -643,11 +922,14 @@ async def bedrock_embedding(
        "AWS_SESSION_TOKEN", aws_session_token
    )

    # Create an aioboto3 session
    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        # Branch on the model provider
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
                # Build the request body according to the model version
                if "v2" in model:
                    body = json.dumps(
                        {
@@ -661,6 +943,7 @@ async def bedrock_embedding(
                else:
                    raise ValueError(f"Model {model} is not supported!")

                # Invoke the Bedrock model
                response = await bedrock_async_client.invoke_model(
                    modelId=model,
                    body=body,
@@ -693,9 +976,22 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    """
    Generate embedding vectors for the given texts with a Hugging Face model.

    Args:
        texts: list of texts to embed
        tokenizer: a Hugging Face tokenizer instance
        embed_model: a Hugging Face embedding model instance
    Returns:
        a NumPy array of embedding vectors
    """
    # Tokenize the texts
    input_ids = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).input_ids
    # Generate the embeddings with the model
    with torch.no_grad():
        outputs = embed_model(input_ids)
        embeddings = outputs.last_hidden_state.mean(dim=1)
@@ -703,9 +999,22 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    """
    Generate embedding vectors for the given texts with an Ollama model.

    Args:
        texts: list of texts to embed
        embed_model: identifier of the embedding model to use
        **kwargs: additional arguments passed to the Ollama client
    Returns:
        a list of embedding vectors
    """
    embed_text = []
    # Create an Ollama client instance
    ollama_client = ollama.Client(**kwargs)
    for text in texts:
        # Call the model to generate the embedding
        data = ollama_client.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])

File diff suppressed because it is too large.

View File

@@ -20,27 +20,63 @@ from .base import (
    BaseVectorStorage,
)


@dataclass
class JsonKVStorage(BaseKVStorage):
    """
    JSON-file-backed key-value storage implementation.
    Inherits from BaseKVStorage and provides basic key-value storage;
    data is persisted as JSON on the file system.
    """

    def __post_init__(self):
        """
        Initialization hook called automatically after object creation.
        - sets the working directory and file path
        - loads any existing JSON data
        """
        working_dir = self.global_config["working_dir"]  # working directory from the global config
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")  # full path of the JSON file
        self._data = load_json(self._file_name) or {}  # load the JSON file, or start with an empty dict
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")  # log how many records were loaded

    async def all_keys(self) -> list[str]:
        """
        Return all keys in the storage.

        Returns:
            list[str]: a list containing every key
        """
        return list(self._data.keys())

    async def index_done_callback(self):
        """
        Callback invoked after indexing completes.
        Writes the in-memory data to the JSON file.
        """
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        """
        Get a single record by ID.

        Args:
            id: ID of the record to look up
        Returns:
            the record, or None if it does not exist
        """
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        """
        Get multiple records by ID.

        Args:
            ids: list of IDs
            fields: optional list of fields to return
        Returns:
            list: one entry per ID with the corresponding data
        """
        if fields is None:
            # If no fields were specified, return the complete records
            return [self._data.get(id, None) for id in ids]
        # Otherwise return only the requested fields
        return [
            (
                {k: v for k, v in self._data[id].items() if k in fields}
@@ -51,38 +87,80 @@ class JsonKVStorage(BaseKVStorage):
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        """
        Filter out the keys that do not yet exist in the storage.

        Args:
            data: list of keys to check
        Returns:
            set[str]: the keys that are not present
        """
        return set([s for s in data if s not in self._data])

    async def upsert(self, data: dict[str, dict]):
        """
        Insert or update data.

        Args:
            data: dict of data to insert/update, shaped as {id: {field: value}}
        Returns:
            dict: the newly inserted records (existing keys are not returned)
        """
        left_data = {k: v for k, v in data.items() if k not in self._data}  # keep only new records
        self._data.update(left_data)  # update the in-memory store
        return left_data  # return the newly inserted records

    async def drop(self):
        """
        Clear all data.
        Resets the in-memory dict to empty.
        """
        self._data = {}
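A small sketch of the key-value semantics documented above; the working directory name is an assumption.

# --- illustrative sketch, not part of this commit ---
import asyncio
import os
from lightrag.storage import JsonKVStorage

async def demo():
    os.makedirs("./demo_cache", exist_ok=True)  # assumed working directory
    kv = JsonKVStorage(namespace="demo", global_config={"working_dir": "./demo_cache"})
    inserted = await kv.upsert({"doc-1": {"content": "hello"}})        # new key, returned
    inserted_again = await kv.upsert({"doc-1": {"content": "hello"}})  # existing key, not returned
    missing = await kv.filter_keys(["doc-1", "doc-2"])                 # only unknown keys survive
    print(inserted, inserted_again, missing)                           # {'doc-1': ...} {} {'doc-2'}
    await kv.index_done_callback()                                     # persist to kv_store_demo.json

asyncio.run(demo())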
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """
    Vector database storage implementation.
    Built on NanoVectorDB; supports inserting, updating, deleting, and querying vectors.
    """

    # Cosine-similarity threshold used to filter search results
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        """
        Initialization hook called automatically after object creation.
        Sets the storage file path and batch size and initializes the vector database client.
        """
        # Build the path of the vector database storage file
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        # Set the batch size
        self._max_batch_size = self.global_config["embedding_batch_num"]
        # Initialize the vector database client
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        # Read the similarity threshold from the config
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """
        Insert or update vector data.

        Args:
            data: dict of vector data shaped as {id: {field: value}}
        Returns:
            list: the upsert results
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []

        # Prepare the data and extract the metadata fields
        list_data = [
            {
                "__id__": k,
@@ -90,28 +168,49 @@ class NanoVectorDBStorage(BaseVectorStorage):
            }
            for k, v in data.items()
        ]

        # Extract the contents and split them into batches
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        # Compute the embeddings in parallel
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)

        # Attach the vectors to the data
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]

        # Perform the upsert
        results = self._client.upsert(datas=list_data)
        return results

    async def query(self, query: str, top_k=5):
        """
        Query the most similar vectors.

        Args:
            query: the query text
            top_k: number of most similar results to return
        Returns:
            list: the similarity results
        """
        # Compute the vector representation of the query text
        embedding = await self.embedding_func([query])
        embedding = embedding[0]

        # Run the vector search
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )

        # Format the returned results
        results = [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
        ]
@@ -119,12 +218,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
    @property
    def client_storage(self):
        """Return the underlying storage object."""
        return getattr(self._client, "_NanoVectorDB__storage")

    async def delete_entity(self, entity_name: str):
        """
        Delete the given entity.

        Args:
            entity_name: name of the entity to delete
        """
        try:
            # Compute the entity ID
            entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]

            # Check for the entity and delete it
            if self._client.get(entity_id):
                self._client.delete(entity_id)
                logger.info(f"Entity {entity_name} have been deleted.")
@@ -134,7 +241,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.error(f"Error while deleting entity {entity_name}: {e}")

    async def delete_relation(self, entity_name: str):
        """
        Delete all relationships involving the given entity.

        Args:
            entity_name: the entity name
        """
        try:
            # Find all related relationships
            relations = [
                dp
                for dp in self.client_storage["data"]
@@ -142,6 +255,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
            ]
            ids_to_delete = [relation["__id__"] for relation in relations]

            # Perform the deletion
            if ids_to_delete:
                self._client.delete(ids_to_delete)
                logger.info(
@@ -155,19 +269,38 @@ class NanoVectorDBStorage(BaseVectorStorage):
            )

    async def index_done_callback(self):
        """Callback invoked after indexing completes; saves the data to the storage file."""
        self._client.save()
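A sketch of exercising this vector storage directly with a toy embedding function; the decorator-provided attributes and the working directory are assumptions, and the query may return nothing if no match clears the cosine threshold.

# --- illustrative sketch, not part of this commit ---
import asyncio
import os
import numpy as np
from lightrag.storage import NanoVectorDBStorage
from lightrag.utils import wrap_embedding_func_with_attrs, compute_mdhash_id

@wrap_embedding_func_with_attrs(embedding_dim=8, max_token_size=512)
async def fake_embedding(texts: list[str]) -> np.ndarray:
    # Deterministic toy embedding so the example runs without any model.
    return np.array([[float(len(t) % (i + 2)) for i in range(8)] for t in texts])

async def demo():
    os.makedirs("./demo_cache", exist_ok=True)  # assumed working directory
    vdb = NanoVectorDBStorage(
        namespace="chunks",
        global_config={"working_dir": "./demo_cache", "embedding_batch_num": 32},
        embedding_func=fake_embedding,
    )
    doc = "LightRAG stores chunk embeddings in NanoVectorDB."
    await vdb.upsert({compute_mdhash_id(doc, prefix="chunk-"): {"content": doc}})
    print(await vdb.query("chunk embeddings", top_k=1))  # matches that pass the threshold, possibly empty

asyncio.run(demo())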
@dataclass
class NetworkXStorage(BaseGraphStorage):
    """
    NetworkX-based graph storage implementation.
    Provides storage, loading, and manipulation of graph data.
    """

    @staticmethod
    def load_nx_graph(file_name) -> nx.Graph:
        """
        Load graph data from a file.

        Args:
            file_name: path to a GraphML file
        Returns:
            nx.Graph: the loaded graph, or None if the file does not exist
        """
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """
        Write graph data to a file.

        Args:
            graph: the graph to save
            file_name: the target path
        """
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
@@ -175,40 +308,51 @@ class NetworkXStorage(BaseGraphStorage):
@staticmethod @staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py """
Return the largest connected component of the graph, with nodes and edges sorted in a stable way. 获取图的最大连通分量并确保节点和边的顺序稳定
参考: https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
参数:
graph: 输入图
返回值:
nx.Graph: 处理后的稳定图
""" """
from graspologic.utils import largest_connected_component from graspologic.utils import largest_connected_component
graph = graph.copy() graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph)) graph = cast(nx.Graph, largest_connected_component(graph))
# Normalize the node labels
node_mapping = { node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes() node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore }
graph = nx.relabel_nodes(graph, node_mapping) graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph) return NetworkXStorage._stabilize_graph(graph)
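The label normalization step above (uppercase, strip, HTML-unescape) can be seen in isolation with a minimal sketch using illustrative node labels:

import html

import networkx as nx

g = nx.Graph()
g.add_edge(" alan&amp;turing ", "computing")

# Same transformation as in stable_largest_connected_component
mapping = {node: html.unescape(node.upper().strip()) for node in g.nodes()}
g = nx.relabel_nodes(g, mapping)
print(list(g.nodes()))  # ['ALAN&TURING', 'COMPUTING']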
@staticmethod @staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph: def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
""" """
确保无向图的关系始终以相同的方式读取
参数:
graph: 输入图
返回值:
nx.Graph: 稳定化后的图
"""
# Create a new graph of the same (directed or undirected) type
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
# Sort the nodes
sorted_nodes = graph.nodes(data=True) sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
# Add the sorted nodes
fixed_graph.add_nodes_from(sorted_nodes) fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True)) edges = list(graph.edges(data=True))
# For undirected graphs, enforce a fixed order of source and target nodes
if not graph.is_directed(): if not graph.is_directed():
def _sort_source_target(edge): def _sort_source_target(edge):
source, target, edge_data = edge source, target, edge_data = edge
if source > target: if source > target:
temp = source source, target = target, source
source = target
target = temp
return source, target, edge_data return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges] edges = [_sort_source_target(edge) for edge in edges]
@@ -216,12 +360,18 @@ class NetworkXStorage(BaseGraphStorage):
def _get_edge_key(source: Any, target: Any) -> str: def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}" return f"{source} -> {target}"
# Sort the edges
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges) fixed_graph.add_edges_from(edges)
return fixed_graph return fixed_graph
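The effect of the edge canonicalization above can be sketched on its own: for an undirected graph, each edge is rewritten as (smaller endpoint, larger endpoint) and the edge list is sorted by the "source -> target" key, so two graphs built in different insertion orders come out identical (the names below are illustrative):

def canonical_edges(edges):
    # edges: iterable of (source, target, data) triples from an undirected graph
    fixed = [(min(s, t), max(s, t), d) for s, t, d in edges]
    return sorted(fixed, key=lambda e: f"{e[0]} -> {e[1]}")

a = [("B", "A", {}), ("C", "A", {})]
b = [("A", "C", {}), ("A", "B", {})]
assert canonical_edges(a) == canonical_edges(b)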
def __post_init__(self): def __post_init__(self):
"""
初始化方法
- 设置图存储文件路径
- 加载已存在的图数据
- 初始化节点嵌入算法
"""
self._graphml_xml_file = os.path.join( self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml" self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
) )
@@ -236,46 +386,56 @@ class NetworkXStorage(BaseGraphStorage):
} }
async def index_done_callback(self): async def index_done_callback(self):
"""索引完成后的回调,保存图数据到文件"""
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
"""检查节点是否存在"""
return self._graph.has_node(node_id) return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""检查边是否存在"""
return self._graph.has_edge(source_node_id, target_node_id) return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
"""获取节点数据"""
return self._graph.nodes.get(node_id) return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""获取节点的度"""
return self._graph.degree(node_id) return self._graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""获取边的度(源节点度 + 目标节点度)"""
return self._graph.degree(src_id) + self._graph.degree(tgt_id) return self._graph.degree(src_id) + self._graph.degree(tgt_id)
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> Union[dict, None]:
"""获取边的数据"""
return self._graph.edges.get((source_node_id, target_node_id)) return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str):
"""获取节点的所有边"""
if self._graph.has_node(source_node_id): if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id)) return list(self._graph.edges(source_node_id))
return None return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""更新或插入节点"""
self._graph.add_node(node_id, **node_data) self._graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ):
"""更新或插入边"""
self._graph.add_edge(source_node_id, target_node_id, **edge_data) self._graph.add_edge(source_node_id, target_node_id, **edge_data)
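A usage sketch for the node and edge accessors above, assuming an initialized NetworkXStorage instance named `graph_storage`; the entity names and payload keys are illustrative:

import asyncio

async def demo(graph_storage):
    await graph_storage.upsert_node("ALAN TURING", {"entity_type": "PERSON"})
    await graph_storage.upsert_node("ENIGMA", {"entity_type": "TECHNOLOGY"})
    await graph_storage.upsert_edge("ALAN TURING", "ENIGMA", {"description": "worked on"})

    print(await graph_storage.has_edge("ALAN TURING", "ENIGMA"))  # True
    print(await graph_storage.node_degree("ALAN TURING"))         # 1
    print(await graph_storage.get_node_edges("ALAN TURING"))      # [('ALAN TURING', 'ENIGMA')]

    await graph_storage.index_done_callback()  # persist the graph to its GraphML file

# asyncio.run(demo(graph_storage))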
async def delete_node(self, node_id: str): async def delete_node(self, node_id: str):
""" """
Delete a node from the graph based on the specified node_id. 删除指定的节点
参数:
:param node_id: The node_id to delete node_id: 要删除的节点ID
""" """
if self._graph.has_node(node_id): if self._graph.has_node(node_id):
self._graph.remove_node(node_id) self._graph.remove_node(node_id)
@@ -284,12 +444,23 @@ class NetworkXStorage(BaseGraphStorage):
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""
使用指定算法进行节点嵌入
参数:
algorithm: 嵌入算法名称
返回值:
tuple: (嵌入向量数组, 节点ID列表)
"""
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
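A sketch of calling the embedding dispatcher above, assuming "node2vec" is the algorithm registered in __post_init__ (that registration is elided in this hunk) and that the optional graspologic dependency is installed:

import asyncio

async def demo(graph_storage):
    # graph_storage: an initialized NetworkXStorage with some nodes and edges (assumed)
    embeddings, node_ids = await graph_storage.embed_nodes("node2vec")
    print(embeddings.shape, len(node_ids))  # one embedding row per node ID

# asyncio.run(demo(graph_storage))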
# @TODO: NOT USED
async def _node2vec_embed(self): async def _node2vec_embed(self):
"""
使用node2vec算法进行节点嵌入未使用
返回值:
tuple: (嵌入向量数组, 节点ID列表)
"""
from graspologic import embed from graspologic import embed
embeddings, nodes = embed.node2vec_embed( embeddings, nodes = embed.node2vec_embed(