181 lines
5.5 KiB
Python
181 lines
5.5 KiB
Python
# 从dataclasses模块导入数据类相关工具
|
||
from dataclasses import (
|
||
dataclass, # 数据类装饰器,用于简化类的定义
|
||
field # 字段函数,用于定义特殊的字段属性
|
||
)
|
||
|
||
# 从typing模块导入类型提示工具
|
||
from typing import (
|
||
TypedDict, # 类型化字典,用于定义具有特定类型的字典
|
||
Union, # 联合类型,表示多个可能的类型之一
|
||
Literal, # 字面量类型,用于限定特定的值
|
||
Generic, # 泛型基类,用于创建泛型类
|
||
TypeVar # 类型变量,用于泛型编程
|
||
)
|
||
|
||
# 导入numpy用于数值计算
|
||
import numpy as np
|
||
|
||
# 从本地utils模块导入嵌入函数类
|
||
from .utils import EmbeddingFunc # 用于处理文本嵌入的函数类
|
||
|
||
# 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号
|
||
class TextChunkSchema(TypedDict):
    """Dictionary layout describing a single text chunk.

    Keys:
        tokens: Token count of the chunk.
        content: Text content of the chunk.
        full_doc_id: ID of the full document the chunk belongs to.
        chunk_order_index: Sequence number of the chunk.
    """

    tokens: int
    content: str
    full_doc_id: str
    chunk_order_index: int
|
||
|
||
# 定义泛型类型变量
|
||
T = TypeVar("T")  # Unconstrained type variable; parameterizes the generic classes in this module.
|
||
|
||
|
||
@dataclass
class QueryParam:
    """Configuration options for a query.

    Attributes:
        mode: Query mode; one of "local", "global", "hybrid" or "naive".
        only_need_context: Whether only the context is needed.
        response_type: Response type.
        top_k: Number of top-k items to retrieve.
        max_token_for_text_unit: Maximum number of tokens for the original
            text chunks.
        max_token_for_global_context: Maximum number of tokens for the
            relationship descriptions.
        max_token_for_local_context: Maximum number of tokens for the
            entity descriptions.
    """

    # Retrieval strategy and output shape.
    mode: Literal["local", "global", "hybrid", "naive"] = "global"
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    top_k: int = 60

    # Token budgets.
    max_token_for_text_unit: int = 4000
    max_token_for_global_context: int = 4000
    max_token_for_local_context: int = 4000
|
||
|
||
|
||
@dataclass
class StorageNameSpace:
    """Base class for storage namespaces.

    Attributes:
        namespace: The namespace.
        global_config: Global configuration dictionary.
    """

    namespace: str
    global_config: dict

    async def index_done_callback(self):
        """Callback run after indexing finishes, used to commit storage operations.

        The base implementation does nothing.
        """

    async def query_done_callback(self):
        """Callback run after a query finishes, used to commit storage operations.

        The base implementation does nothing.
        """
|
||
|
||
|
||
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Base class for vector storage.

    Attributes:
        embedding_func: Embedding function.
        meta_fields: Set of metadata fields; each instance gets its own
            empty set by default.
    """

    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Query interface.

        Raises:
            NotImplementedError: Always; the base class has no implementation.
        """
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Update or insert data.

        The 'content' field of each value is used for embedding and the key
        is used as the ID. If embedding_func is None, the 'embedding' field
        of the value is used instead.

        Raises:
            NotImplementedError: Always; the base class has no implementation.
        """
        raise NotImplementedError
|
||
|
||
|
||
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Base class for key-value storage.

    Every method raises NotImplementedError in this base class.
    """

    async def all_keys(self) -> list[str]:
        """Get all keys."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Get a value by its ID."""
        raise NotImplementedError

    async def get_by_ids(
        self,
        ids: list[str],
        fields: Union[set[str], None] = None,
    ) -> list[Union[T, None]]:
        """Get values in batch for a list of IDs."""
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the set of keys that do not exist."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Update or insert data."""
        raise NotImplementedError

    async def drop(self):
        """Delete the storage."""
        raise NotImplementedError
|
||
|
||
|
||
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Base class for graph storage.

    Every method raises NotImplementedError in this base class.
    """

    # --- Existence checks and degrees ---

    async def has_node(self, node_id: str) -> bool:
        """Check whether a node exists."""
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Check whether an edge exists."""
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        """Get the degree of a node."""
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Get the degree of an edge."""
        raise NotImplementedError

    # --- Lookups ---

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Get the information of a node."""
        raise NotImplementedError

    async def get_edge(
        self,
        source_node_id: str,
        target_node_id: str,
    ) -> Union[dict, None]:
        """Get the information of an edge."""
        raise NotImplementedError

    async def get_node_edges(
        self,
        source_node_id: str,
    ) -> Union[list[tuple[str, str]], None]:
        """Get all edges of a node."""
        raise NotImplementedError

    # --- Mutations ---

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Update or insert a node."""
        raise NotImplementedError

    async def upsert_edge(
        self,
        source_node_id: str,
        target_node_id: str,
        edge_data: dict[str, str],
    ):
        """Update or insert an edge."""
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        """Delete a node."""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Node embedding (not used in lightrag)."""
        raise NotImplementedError("Node embedding is not used in lightrag.")
|