# lightrag-comments/lightrag/base.py
# 从dataclasses模块导入数据类相关工具
from dataclasses import (
dataclass, # 数据类装饰器,用于简化类的定义
field # 字段函数,用于定义特殊的字段属性
)
# 从typing模块导入类型提示工具
from typing import (
TypedDict, # 类型化字典,用于定义具有特定类型的字典
Union, # 联合类型,表示多个可能的类型之一
Literal, # 字面量类型,用于限定特定的值
Generic, # 泛型基类,用于创建泛型类
TypeVar # 类型变量,用于泛型编程
)
# 导入numpy用于数值计算
import numpy as np
# 从本地utils模块导入嵌入函数类
from .utils import EmbeddingFunc # 用于处理文本嵌入的函数类
# 定义文本块的数据结构包含令牌数、内容、完整文档ID和块序号
TextChunkSchema = TypedDict(
"TextChunkSchema",
{"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)
2024-11-16 11:40:45 +08:00
# 定义泛型类型变量
T = TypeVar("T")
@dataclass
class QueryParam:
    """Configuration parameters for a query.

    Attributes:
        mode: Query mode, one of "local", "global", "hybrid" or "naive".
        only_need_context: Whether only the context is needed.
        response_type: The response type.
        top_k: Number of top-k items to retrieve.
        max_token_for_text_unit: Maximum number of tokens for original text chunks.
        max_token_for_global_context: Maximum number of tokens for relationship
            descriptions.
        max_token_for_local_context: Maximum number of tokens for entity
            descriptions.
    """

    mode: Literal["local", "global", "hybrid", "naive"] = "global"
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    top_k: int = 60
    max_token_for_text_unit: int = 4000
    max_token_for_global_context: int = 4000
    max_token_for_local_context: int = 4000


@dataclass
class StorageNameSpace:
    """Base class for storages, identified by a namespace.

    Attributes:
        namespace: The namespace of this storage.
        global_config: Global configuration dictionary.
    """

    namespace: str
    global_config: dict

    async def index_done_callback(self):
        """Callback invoked after indexing completes, used to commit storage operations.

        The base implementation does nothing; subclasses may override it.
        """
        pass

    async def query_done_callback(self):
        """Callback invoked after a query completes, used to commit storage operations.

        The base implementation does nothing; subclasses may override it.
        """
        pass


@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Base class for vector storages.

    Attributes:
        embedding_func: The embedding function.
        meta_fields: Set of metadata fields (empty by default).
    """

    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Query interface: return results for ``query`` as a list of dicts.

        ``top_k`` presumably bounds the number of results; confirm against
        concrete implementations. Abstract: raises NotImplementedError.
        """
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Insert or update data.

        Implementations are expected to embed each value's 'content' field and
        use the key as the ID. If embedding_func is None, the value's
        'embedding' field is used instead. Abstract: raises NotImplementedError.
        """
        raise NotImplementedError


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Base class for key-value storages whose values are of type ``T``.

    Every method here is abstract and raises NotImplementedError.
    """

    async def all_keys(self) -> list[str]:
        """Return all keys."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Return the value stored under ``id``.

        ``None`` presumably signals a missing key; confirm in implementations.
        """
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        """Return the values for a list of IDs in one batch.

        ``fields`` presumably restricts which fields of each value are
        returned (None = all); confirm in implementations.
        """
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the set of keys from ``data`` that do not exist in the storage."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Insert or update the given key -> value pairs."""
        raise NotImplementedError

    async def drop(self):
        """Delete the storage."""
        raise NotImplementedError


@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Base class for graph storages.

    Every method here is abstract and raises NotImplementedError.
    """

    async def has_node(self, node_id: str) -> bool:
        """Return whether the node ``node_id`` exists."""
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return whether an edge between the two given nodes exists."""
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of the node ``node_id``."""
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the degree of the edge (``src_id``, ``tgt_id``).

        How an edge's degree is defined is left to implementations.
        """
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the node's information, or None (presumably if it does not exist)."""
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the edge's information, or None (presumably if it does not exist)."""
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        """Return all edges of the node, presumably as (source, target) ID pairs."""
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Insert or update the node ``node_id`` with ``node_data``."""
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Insert or update the edge between the two nodes with ``edge_data``."""
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        """Delete the node ``node_id``."""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Node embedding; not used in LightRAG, so this always raises."""
        raise NotImplementedError("Node embedding is not used in lightrag.")