from dataclasses import dataclass, field from typing import TypedDict, Union, Literal, Generic, TypeVar import numpy as np from .utils import EmbeddingFunc # 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号 TextChunkSchema = TypedDict( "TextChunkSchema", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}, ) # 定义泛型类型变量 T = TypeVar("T") @dataclass class QueryParam: """查询参数配置类 属性: mode: 查询模式,可选 'local'(局部)、'global'(全局)、'hybrid'(混合)或'naive'(朴素) only_need_context: 是否仅需要上下文 response_type: 响应类型 top_k: 检索的top-k项数量 max_token_for_text_unit: 原始文本块的最大令牌数 max_token_for_global_context: 关系描述的最大令牌数 max_token_for_local_context: 实体描述的最大令牌数 """ mode: Literal["local", "global", "hybrid", "naive"] = "global" only_need_context: bool = False response_type: str = "Multiple Paragraphs" top_k: int = 60 max_token_for_text_unit: int = 4000 max_token_for_global_context: int = 4000 max_token_for_local_context: int = 4000 @dataclass class StorageNameSpace: """存储命名空间基类 属性: namespace: 命名空间 global_config: 全局配置字典 """ namespace: str global_config: dict async def index_done_callback(self): """索引完成后的回调函数,用于提交存储操作""" pass async def query_done_callback(self): """查询完成后的回调函数,用于提交存储操作""" pass @dataclass class BaseVectorStorage(StorageNameSpace): """向量存储基类 属性: embedding_func: 嵌入函数 meta_fields: 元数据字段集合 """ embedding_func: EmbeddingFunc meta_fields: set = field(default_factory=set) async def query(self, query: str, top_k: int) -> list[dict]: """查询接口""" raise NotImplementedError async def upsert(self, data: dict[str, dict]): """更新或插入数据 使用value中的'content'字段进行嵌入,使用key作为ID 如果embedding_func为None,则使用value中的'embedding'字段 """ raise NotImplementedError @dataclass class BaseKVStorage(Generic[T], StorageNameSpace): """键值存储基类""" async def all_keys(self) -> list[str]: """获取所有键""" raise NotImplementedError async def get_by_id(self, id: str) -> Union[T, None]: """通过ID获取值""" raise NotImplementedError async def get_by_ids( self, ids: list[str], fields: Union[set[str], None] = None ) -> list[Union[T, None]]: """通过ID列表批量获取值""" raise NotImplementedError async def filter_keys(self, data: list[str]) -> set[str]: """返回不存在的键集合""" raise NotImplementedError async def upsert(self, data: dict[str, T]): """更新或插入数据""" raise NotImplementedError async def drop(self): """删除存储""" raise NotImplementedError @dataclass class BaseGraphStorage(StorageNameSpace): """图存储基类""" async def has_node(self, node_id: str) -> bool: """检查节点是否存在""" raise NotImplementedError async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """检查边是否存在""" raise NotImplementedError async def node_degree(self, node_id: str) -> int: """获取节点的度""" raise NotImplementedError async def edge_degree(self, src_id: str, tgt_id: str) -> int: """获取边的度""" raise NotImplementedError async def get_node(self, node_id: str) -> Union[dict, None]: """获取节点信息""" raise NotImplementedError async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: """获取边信息""" raise NotImplementedError async def get_node_edges( self, source_node_id: str ) -> Union[list[tuple[str, str]], None]: """获取节点的所有边""" raise NotImplementedError async def upsert_node(self, node_id: str, node_data: dict[str, str]): """更新或插入节点""" raise NotImplementedError async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): """更新或插入边""" raise NotImplementedError async def delete_node(self, node_id: str): """删除节点""" raise NotImplementedError async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: """节点嵌入(在lightrag中未使用)""" raise NotImplementedError("Node embedding is not used in lightrag.")