lightrag-comments/lightrag/base.py

167 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar
import numpy as np
from .utils import EmbeddingFunc
# Schema of a single text chunk: token count, content, full document ID,
# and the chunk's order index.
class TextChunkSchema(TypedDict):
    tokens: int
    content: str
    full_doc_id: str
    chunk_order_index: int


# Generic type variable.
T = TypeVar("T")
@dataclass
class QueryParam:
    """Configuration options for a query.

    Attributes:
        mode: Query mode; one of 'local', 'global', 'hybrid' or 'naive'.
        only_need_context: Whether only the context is needed.
        response_type: Response type.
        top_k: Number of top-k items to retrieve.
        max_token_for_text_unit: Maximum number of tokens for the original
            text chunks.
        max_token_for_global_context: Maximum number of tokens for
            relationship descriptions.
        max_token_for_local_context: Maximum number of tokens for entity
            descriptions.
    """

    # Retrieval behaviour.
    mode: Literal["local", "global", "hybrid", "naive"] = "global"
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    top_k: int = 60

    # Token budgets.
    max_token_for_text_unit: int = 4000
    max_token_for_global_context: int = 4000
    max_token_for_local_context: int = 4000
@dataclass
class StorageNameSpace:
    """Base class for namespaced storage.

    Attributes:
        namespace: The namespace.
        global_config: Global configuration dictionary.
    """

    namespace: str
    global_config: dict

    async def index_done_callback(self):
        """Callback invoked after indexing is done, used to commit storage operations.

        The base implementation does nothing.
        """
        return None

    async def query_done_callback(self):
        """Callback invoked after a query is done, used to commit storage operations.

        The base implementation does nothing.
        """
        return None
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Base class for vector storage.

    Attributes:
        embedding_func: The embedding function.
        meta_fields: Set of metadata fields.
    """

    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Query interface.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Update or insert data.

        The 'content' field of each value is used for embedding and the key
        is used as the ID. If embedding_func is None, the 'embedding' field
        of the value is used instead.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Base class for key-value storage.

    Every method raises NotImplementedError in this base class.
    """

    async def all_keys(self) -> list[str]:
        """Get all keys."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Get a value by its ID."""
        raise NotImplementedError

    async def get_by_ids(
        self,
        ids: list[str],
        fields: Union[set[str], None] = None,
    ) -> list[Union[T, None]]:
        """Get values in batch for a list of IDs.

        NOTE(review): `fields` presumably restricts which fields of each
        value are returned -- confirm against the concrete implementations.
        """
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the set of keys that do not exist."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Update or insert data."""
        raise NotImplementedError

    async def drop(self):
        """Drop the storage."""
        raise NotImplementedError
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Base class for graph storage.

    Every method raises NotImplementedError in this base class.
    """

    async def has_node(self, node_id: str) -> bool:
        """Check whether a node exists."""
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Check whether an edge exists."""
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        """Get the degree of a node."""
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Get the degree of an edge."""
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Get the information of a node."""
        raise NotImplementedError

    async def get_edge(
        self,
        source_node_id: str,
        target_node_id: str,
    ) -> Union[dict, None]:
        """Get the information of an edge."""
        raise NotImplementedError

    async def get_node_edges(
        self,
        source_node_id: str,
    ) -> Union[list[tuple[str, str]], None]:
        """Get all edges of a node."""
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Update or insert a node."""
        raise NotImplementedError

    async def upsert_edge(
        self,
        source_node_id: str,
        target_node_id: str,
        edge_data: dict[str, str],
    ):
        """Update or insert an edge."""
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        """Delete a node."""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Node embedding; not used in lightrag."""
        raise NotImplementedError("Node embedding is not used in lightrag.")