lightrag-comments/lightrag/storage.py

483 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 标准库导入
import asyncio # 异步IO支持
import html # HTML实体编解码
import os # 操作系统接口,用于文件和路径操作
# 数据类和类型提示相关导入
from dataclasses import dataclass # 数据类装饰器
from typing import (
Any, # 任意类型
Union, # 联合类型
cast # 类型转换函数
)
# 第三方库导入
import networkx as nx # 图数据处理库
import numpy as np # 数值计算库
from nano_vectordb import NanoVectorDB # 向量数据库
# 从本地utils模块导入工具函数
from .utils import (
logger, # 日志记录器
load_json, # JSON文件加载函数
write_json, # JSON文件写入函数
compute_mdhash_id, # 计算MD5哈希ID的函数
)
# 从本地base模块导入基础存储类
from .base import (
BaseGraphStorage, # 图存储基类
BaseKVStorage, # 键值存储基类
BaseVectorStorage, # 向量存储基类
)
@dataclass
class JsonKVStorage(BaseKVStorage):
"""
基于JSON文件的键值存储实现类
继承自BaseKVStorage提供基本的键值存储功能
数据以JSON格式保存在文件系统中
"""
def __post_init__(self):
"""
初始化方法,在对象创建后自动调用
- 设置工作目录和文件路径
- 加载已存在的JSON数据
"""
working_dir = self.global_config["working_dir"] # 从全局配置获取工作目录
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") # 构建JSON文件完整路径
self._data = load_json(self._file_name) or {} # 加载JSON文件如果不存在则初始化为空字典
logger.info(f"Load KV {self.namespace} with {len(self._data)} data") # 记录加载数据的数量
async def all_keys(self) -> list[str]:
"""
获取存储中的所有键
返回值:
list[str]: 包含所有键的列表
"""
return list(self._data.keys())
async def index_done_callback(self):
"""
索引完成后的回调函数
将当前内存中的数据写入JSON文件
"""
write_json(self._data, self._file_name)
async def get_by_id(self, id):
"""
通过ID获取单个数据
参数:
id: 要查询的数据ID
返回值:
查找到的数据如果不存在则返回None
"""
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
"""
批量获取多个ID的数据
参数:
ids: ID列表
fields: 可选,指定要返回的字段列表
返回值:
list: 包含查询结果的列表每个元素对应一个ID的数据
"""
if fields is None:
# 如果未指定字段,返回完整数据
return [self._data.get(id, None) for id in ids]
# 如果指定了字段,只返回指定的字段
return [
(
{k: v for k, v in self._data[id].items() if k in fields}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, data: list[str]) -> set[str]:
"""
过滤出不存在于存储中的键
参数:
data: 要检查的键列表
返回值:
set[str]: 不存在的键集合
"""
return set([s for s in data if s not in self._data])
async def upsert(self, data: dict[str, dict]):
"""
更新或插入数据
参数:
data: 要更新/插入的数据字典,格式为 {id: {字段: 值}}
返回值:
dict: 实际插入的新数据(不包含更新的数据)
"""
left_data = {k: v for k, v in data.items() if k not in self._data} # 筛选出新数据
self._data.update(left_data) # 更新存储
return left_data # 返回新插入的数据
async def drop(self):
"""
清空所有数据
将内存中的数据字典重置为空
"""
self._data = {}
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
"""
向量数据库存储实现类
基于NanoVectorDB实现向量存储和检索功能
支持向量的增删改查操作
"""
# 余弦相似度阈值,用于过滤搜索结果
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
"""
初始化方法,在对象创建后自动调用
设置存储文件路径、批处理大小,并初始化向量数据库客户端
"""
# 构建向量数据库存储文件路径
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
# 设置批处理大小
self._max_batch_size = self.global_config["embedding_batch_num"]
# 初始化向量数据库客户端
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
# 从配置中获取相似度阈值
self.cosine_better_than_threshold = self.global_config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]):
"""
更新或插入向量数据
参数:
data: 包含向量数据的字典,格式为 {id: {字段: 值}}
返回值:
list: 插入结果
"""
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
# 准备数据,提取元数据字段
list_data = [
{
"__id__": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
# 提取内容并分批处理
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
# 并行计算向量嵌入
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
# 将向量添加到数据中
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# 执行更新/插入操作
results = self._client.upsert(datas=list_data)
return results
async def query(self, query: str, top_k=5):
"""
查询最相似的向量
参数:
query: 查询文本
top_k: 返回的最相似结果数量
返回值:
list: 包含相似度结果的列表
"""
# 计算查询文本的向量表示
embedding = await self.embedding_func([query])
embedding = embedding[0]
# 执行向量检索
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
# 格式化返回结果
results = [
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
return results
@property
def client_storage(self):
"""获取底层存储对象"""
return getattr(self._client, "_NanoVectorDB__storage")
async def delete_entity(self, entity_name: str):
"""
删除指定实体
参数:
entity_name: 要删除的实体名称
"""
try:
# 计算实体ID
entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
# 检查并删除实体
if self._client.get(entity_id):
self._client.delete(entity_id)
logger.info(f"Entity {entity_name} have been deleted.")
else:
logger.info(f"No entity found with name {entity_name}.")
except Exception as e:
logger.error(f"Error while deleting entity {entity_name}: {e}")
async def delete_relation(self, entity_name: str):
"""
删除与指定实体相关的所有关系
参数:
entity_name: 实体名称
"""
try:
# 查找所有相关关系
relations = [
dp
for dp in self.client_storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
ids_to_delete = [relation["__id__"] for relation in relations]
# 执行删除操作
if ids_to_delete:
self._client.delete(ids_to_delete)
logger.info(
f"All relations related to entity {entity_name} have been deleted."
)
else:
logger.info(f"No relations found for entity {entity_name}.")
except Exception as e:
logger.error(
f"Error while deleting relations for entity {entity_name}: {e}"
)
async def index_done_callback(self):
"""索引完成后的回调函数,保存数据到存储文件"""
self._client.save()
@dataclass
class NetworkXStorage(BaseGraphStorage):
"""
基于NetworkX的图存储实现类
提供图数据的存储、读取和操作功能
"""
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
"""
从文件加载图数据
参数:
file_name: GraphML文件路径
返回值:
nx.Graph: 加载的图对象如果文件不存在返回None
"""
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
"""
将图数据写入文件
参数:
graph: 要保存的图对象
file_name: 保存路径
"""
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
nx.write_graphml(graph, file_name)
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""
获取图的最大连通分量,并确保节点和边的顺序稳定
参考: https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
参数:
graph: 输入图
返回值:
nx.Graph: 处理后的稳定图
"""
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
# 对节点标签进行标准化处理
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
}
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""
确保无向图的关系始终以相同的方式读取
参数:
graph: 输入图
返回值:
nx.Graph: 稳定化后的图
"""
# 根据图的类型创建新图
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
# 对节点进行排序
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
# 添加排序后的节点
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
# 对于无向图,确保边的源节点和目标节点有固定顺序
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
source, target = target, source
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
# 对边进行排序
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
"""
初始化方法
- 设置图存储文件路径
- 加载已存在的图数据
- 初始化节点嵌入算法
"""
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
"""索引完成后的回调,保存图数据到文件"""
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
"""检查节点是否存在"""
return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""检查边是否存在"""
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
"""获取节点数据"""
return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
"""获取节点的度"""
return self._graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""获取边的度(源节点度 + 目标节点度)"""
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""获取边的数据"""
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
"""获取节点的所有边"""
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""更新或插入节点"""
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
"""更新或插入边"""
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
删除指定的节点
参数:
node_id: 要删除的节点ID
"""
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""
使用指定算法进行节点嵌入
参数:
algorithm: 嵌入算法名称
返回值:
tuple: (嵌入向量数组, 节点ID列表)
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def _node2vec_embed(self):
"""
使用node2vec算法进行节点嵌入未使用
返回值:
tuple: (嵌入向量数组, 节点ID列表)
"""
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids