使用cursor加注释。
This commit is contained in:
parent
6969e4afc1
commit
3970233c36
@ -1,9 +1,23 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict, Union, Literal, Generic, TypeVar
|
||||
# 从dataclasses模块导入数据类相关工具
|
||||
from dataclasses import (
|
||||
dataclass, # 数据类装饰器,用于简化类的定义
|
||||
field # 字段函数,用于定义特殊的字段属性
|
||||
)
|
||||
|
||||
# 从typing模块导入类型提示工具
|
||||
from typing import (
|
||||
TypedDict, # 类型化字典,用于定义具有特定类型的字典
|
||||
Union, # 联合类型,表示多个可能的类型之一
|
||||
Literal, # 字面量类型,用于限定特定的值
|
||||
Generic, # 泛型基类,用于创建泛型类
|
||||
TypeVar # 类型变量,用于泛型编程
|
||||
)
|
||||
|
||||
# 导入numpy用于数值计算
|
||||
import numpy as np
|
||||
|
||||
from .utils import EmbeddingFunc
|
||||
# 从本地utils模块导入嵌入函数类
|
||||
from .utils import EmbeddingFunc # 用于处理文本嵌入的函数类
|
||||
|
||||
# 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号
|
||||
TextChunkSchema = TypedDict(
|
||||
|
@ -1,103 +1,181 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, Tuple, List, Dict
|
||||
import inspect
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
# 标准库导入
|
||||
import asyncio # 异步IO支持
|
||||
import os # 操作系统接口,用于环境变量访问
|
||||
import inspect # 用于运行时检查Python对象
|
||||
|
||||
# 数据类和类型提示相关导入
|
||||
from dataclasses import dataclass # 数据类装饰器
|
||||
from typing import (
|
||||
Any, # 任意类型
|
||||
Union, # 联合类型
|
||||
Tuple, # 元组类型
|
||||
List, # 列表类型
|
||||
Dict # 字典类型
|
||||
)
|
||||
|
||||
# 本地模块导入
|
||||
from lightrag.utils import logger # 日志记录器
|
||||
from ..base import BaseGraphStorage # 图存储基类
|
||||
|
||||
# Neo4j相关导入
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase, # Neo4j异步图数据库驱动
|
||||
exceptions as neo4jExceptions, # Neo4j异常类
|
||||
AsyncDriver, # Neo4j异步驱动接口
|
||||
AsyncManagedTransaction, # Neo4j异步事务管理
|
||||
)
|
||||
|
||||
# 重试机制相关导入
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
retry, # 重试装饰器
|
||||
stop_after_attempt, # 最大重试次数限制
|
||||
wait_exponential, # 指数退避等待策略
|
||||
retry_if_exception_type, # 基于异常类型的重试条件
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
"""Neo4j图数据库存储实现类"""
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
"""加载NetworkX图的静态方法(生产环境中未使用)
|
||||
|
||||
Args:
|
||||
file_name: 图文件名
|
||||
"""
|
||||
print("no preloading of graph with neo4j in production")
|
||||
|
||||
def __init__(self, namespace, global_config):
    """Initialise the storage and create the async Neo4j driver.

    Args:
        namespace: Forwarded unchanged to the base-class initializer.
        global_config: Forwarded unchanged to the base-class initializer.

    Raises:
        KeyError: If ``NEO4J_URI``, ``NEO4J_USERNAME`` or ``NEO4J_PASSWORD``
            is missing from the environment.
    """
    super().__init__(namespace=namespace, global_config=global_config)

    # asyncio lock kept on the instance; it is not acquired in this method.
    self._driver_lock = asyncio.Lock()

    # Connection settings are read from the environment only; a missing
    # variable raises KeyError before any driver is created.
    uri = os.environ["NEO4J_URI"]
    username = os.environ["NEO4J_USERNAME"]
    password = os.environ["NEO4J_PASSWORD"]

    self._driver: AsyncDriver = AsyncGraphDatabase.driver(
        uri, auth=(username, password)
    )
|
||||
|
||||
def __post_init__(self):
|
||||
"""数据类后初始化方法,设置节点嵌入算法"""
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def close(self):
    """Close the Neo4j driver if one is held, then forget it."""
    driver = self._driver
    if not driver:
        # Nothing open (or already closed): leave the attribute untouched.
        return
    await driver.close()
    self._driver = None
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""异步上下文管理器的退出方法"""
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
|
||||
async def index_done_callback(self):
    """Print a notice that knowledge-graph indexing finished."""
    notice = "KG successfully indexed."
    print(notice)
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
    """Report whether any node carries the label derived from ``node_id``.

    Args:
        node_id: Node identifier. Surrounding double quotes are stripped and
            the remainder is used as the Neo4j label.

    Returns:
        The ``node_exists`` value produced by the count query.
    """
    # The label is embedded in a backtick-quoted identifier rather than sent
    # as a query parameter. Double every backtick in it so the input cannot
    # terminate the identifier and inject Cypher.
    entity_name_label = node_id.strip('"').replace("`", "``")
    query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"

    async with self._driver.session() as session:
        result = await session.run(query)
        single_result = await result.single()
        node_exists = single_result["node_exists"]
        logger.debug(
            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{node_exists}"
        )
        return node_exists
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Report whether any relationship links the two labelled nodes.

    The pattern is undirected, so a relationship in either direction counts.

    Args:
        source_node_id: Identifier of one endpoint. Surrounding double quotes
            are stripped and the remainder is used as the Neo4j label.
        target_node_id: Identifier of the other endpoint, treated the same way.

    Returns:
        The ``edgeExists`` value produced by the count query.
    """
    # Both labels are embedded in backtick-quoted identifiers rather than sent
    # as query parameters. Double every backtick so the input cannot terminate
    # the identifier and inject Cypher.
    entity_name_label_source = source_node_id.strip('"').replace("`", "``")
    entity_name_label_target = target_node_id.strip('"').replace("`", "``")
    query = (
        f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
        "RETURN COUNT(r) > 0 AS edgeExists"
    )

    async with self._driver.session() as session:
        result = await session.run(query)
        single_result = await result.single()
        edge_exists = single_result["edgeExists"]
        logger.debug(
            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_exists}"
        )
        return edge_exists
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""获取节点信息
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
|
||||
Returns:
|
||||
dict: 节点属性字典,如果节点不存在则返回None
|
||||
"""
|
||||
async with self._driver.session() as session:
|
||||
# 清理节点ID中的引号
|
||||
entity_name_label = node_id.strip('"')
|
||||
# 构建Cypher查询
|
||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||
# 执行查询
|
||||
result = await session.run(query)
|
||||
record = await result.single()
|
||||
if record:
|
||||
# 提取节点数据并转换为字典
|
||||
node = record["n"]
|
||||
node_dict = dict(node)
|
||||
# 记录调试日志
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
|
||||
)
|
||||
@ -105,9 +183,19 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return None
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""获取节点的度(与节点相连的边的数量)
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
|
||||
Returns:
|
||||
int: 节点的度,如果节点不存在则返回None
|
||||
"""
|
||||
# 清理节点ID中的引号
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
# 构建Cypher查询,计算节点的总边数
|
||||
query = f"""
|
||||
MATCH (n:`{entity_name_label}`)
|
||||
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
||||
@ -116,6 +204,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
record = await result.single()
|
||||
if record:
|
||||
edge_count = record["totalEdgeCount"]
|
||||
# 记录调试日志
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
|
||||
)
|
||||
@ -124,15 +213,28 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return None
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""计算边的度(源节点和目标节点的度之和)
|
||||
|
||||
Args:
|
||||
src_id: 源节点ID
|
||||
tgt_id: 目标节点ID
|
||||
|
||||
Returns:
|
||||
int: 边的度(两个节点的度之和)
|
||||
"""
|
||||
# 清理节点ID中的引号
|
||||
entity_name_label_source = src_id.strip('"')
|
||||
entity_name_label_target = tgt_id.strip('"')
|
||||
|
||||
# 获取源节点和目标节点的度
|
||||
src_degree = await self.node_degree(entity_name_label_source)
|
||||
trg_degree = await self.node_degree(entity_name_label_target)
|
||||
|
||||
# Convert None to 0 for addition
|
||||
# 将None转换为0以进行加法运算
|
||||
src_degree = 0 if src_degree is None else src_degree
|
||||
trg_degree = 0 if trg_degree is None else trg_degree
|
||||
|
||||
# 计算总度数并记录日志
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
|
||||
@ -142,19 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""获取两个节点之间的边的属性
|
||||
|
||||
Args:
|
||||
source_node_id: 源节点ID
|
||||
target_node_id: 目标节点ID
|
||||
|
||||
Returns:
|
||||
dict: 边的属性字典,如果边不存在则返回None
|
||||
"""
|
||||
# 清理节点ID中的引号
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
"""
|
||||
Find all edges between nodes of two given labels
|
||||
|
||||
Args:
|
||||
source_node_label (str): Label of the source nodes
|
||||
target_node_label (str): Label of the target nodes
|
||||
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
async with self._driver.session() as session:
|
||||
# 构建Cypher查询,获取边的属性
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
@ -164,9 +268,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_name_label_target=entity_name_label_target,
|
||||
)
|
||||
|
||||
# 执行查询
|
||||
result = await session.run(query)
|
||||
record = await result.single()
|
||||
if record:
|
||||
# 转换结果为字典并记录日志
|
||||
result = dict(record["edge_properties"])
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
||||
@ -176,40 +282,49 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
    """List the relationships attached to the node labelled by ``source_node_id``.

    Args:
        source_node_id: Node identifier. Surrounding double quotes are
            stripped and the remainder is used as the Neo4j label.

    Returns:
        One ``(source_label, connected_label)`` tuple per returned row in
        which both nodes have at least one label. Each label is the first one
        yielded by the node's ``labels`` collection. Rows without a connected
        node (from the OPTIONAL MATCH) are skipped.
    """
    # The label is embedded in a backtick-quoted identifier rather than sent
    # as a query parameter. Double every backtick so the input cannot
    # terminate the identifier and inject Cypher.
    node_label = source_node_id.strip('"').replace("`", "``")
    query = (
        f"MATCH (n:`{node_label}`)\n"
        "OPTIONAL MATCH (n)-[r]-(connected)\n"
        "RETURN n, r, connected"
    )

    edges = []
    async with self._driver.session() as session:
        results = await session.run(query)
        async for record in results:
            source_node = record["n"]
            connected_node = record["connected"]

            source_label = (
                next(iter(source_node.labels)) if source_node.labels else None
            )
            target_label = (
                next(iter(connected_node.labels))
                if connected_node and connected_node.labels
                else None
            )

            # Keep the pair only when both ends resolved to a label.
            if source_label and target_label:
                edges.append((source_label, target_label))

    return edges
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
stop=stop_after_attempt(3), # 最多重试3次
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待
|
||||
retry=retry_if_exception_type( # 指定需要重试的异常类型
|
||||
(
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
@ -218,17 +333,17 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
"""
|
||||
Upsert a node in the Neo4j database.
|
||||
|
||||
"""更新或插入节点
|
||||
|
||||
Args:
|
||||
node_id: The unique identifier for the node (used as label)
|
||||
node_data: Dictionary of node properties
|
||||
node_id: 节点的唯一标识符(用作标签)
|
||||
node_data: 节点属性字典
|
||||
"""
|
||||
label = node_id.strip('"')
|
||||
properties = node_data
|
||||
|
||||
async def _do_upsert(tx: AsyncManagedTransaction):
|
||||
"""执行节点更新/插入的内部函数"""
|
||||
query = f"""
|
||||
MERGE (n:`{label}`)
|
||||
SET n += $properties
|
||||
@ -246,9 +361,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
raise
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
stop=stop_after_attempt(3), # 最多重试3次
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10), # 指数退避等待
|
||||
retry=retry_if_exception_type( # 指定需要重试的异常类型
|
||||
(
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
@ -259,19 +374,19 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels.
|
||||
|
||||
"""更新或插入边及其属性
|
||||
|
||||
Args:
|
||||
source_node_id (str): Label of the source node (used as identifier)
|
||||
target_node_id (str): Label of the target node (used as identifier)
|
||||
edge_data (dict): Dictionary of properties to set on the edge
|
||||
source_node_id: 源节点标签(用作标识符)
|
||||
target_node_id: 目标节点标签(用作标识符)
|
||||
edge_data: 边属性字典
|
||||
"""
|
||||
source_node_label = source_node_id.strip('"')
|
||||
target_node_label = target_node_id.strip('"')
|
||||
edge_properties = edge_data
|
||||
|
||||
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
||||
"""执行边更新/插入的内部函数"""
|
||||
query = f"""
|
||||
MATCH (source:`{source_node_label}`)
|
||||
WITH source
|
||||
@ -293,4 +408,5 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
raise
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
"""节点嵌入方法(未实际使用)"""
|
||||
print("Implemented but never called.")
|
||||
|
@ -3,12 +3,30 @@ LightRAG - 轻量级检索增强生成系统
|
||||
该模块实现了一个基于图的文档检索和问答系统,支持文档的存储、检索和知识图谱构建
|
||||
"""
|
||||
|
||||
# 导入异步IO模块,用于处理异步编程
|
||||
import asyncio
|
||||
|
||||
# 导入操作系统接口模块,用于处理文件路径和环境变量
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
# 从dataclasses模块导入数据类相关工具
|
||||
from dataclasses import (
|
||||
asdict, # 将数据类实例转换为字典的函数
|
||||
dataclass, # 数据类装饰器
|
||||
field, # 用于定义数据类字段的函数
|
||||
)
|
||||
|
||||
# 导入日期时间处理模块
|
||||
from datetime import datetime
|
||||
|
||||
# 从functools导入partial函数,用于创建偏函数
|
||||
from functools import partial
|
||||
from typing import Type, cast
|
||||
|
||||
# 从typing模块导入类型提示工具
|
||||
from typing import (
|
||||
Type, # 用于类型注解中表示类型的类型
|
||||
cast, # 用于类型转换的函数
|
||||
)
|
||||
|
||||
# 导入LLM相关功能
|
||||
from .llm import (
|
||||
|
@ -1,35 +1,51 @@
|
||||
import os
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import aioboto3
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import ollama
|
||||
# 标准库导入
|
||||
import os # 操作系统接口
|
||||
import copy # 深浅拷贝功能
|
||||
from functools import lru_cache # 最近最少使用缓存装饰器
|
||||
import json # JSON数据处理
|
||||
import base64 # Base64编解码
|
||||
import struct # 处理二进制数据结构
|
||||
|
||||
# 第三方异步库
|
||||
import aioboto3 # AWS SDK的异步版本
|
||||
import aiohttp # 异步HTTP客户端/服务器
|
||||
import ollama # Ollama API客户端
|
||||
|
||||
# 数值计算和机器学习库
|
||||
import numpy as np # 数值计算库
|
||||
import torch # PyTorch深度学习框架
|
||||
from transformers import ( # Hugging Face转换器库
|
||||
AutoTokenizer, # 自动分词器
|
||||
AutoModelForCausalLM, # 自动因果语言模型
|
||||
)
|
||||
|
||||
# OpenAI相关导入
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
AsyncAzureOpenAI,
|
||||
AsyncOpenAI, # OpenAI异步客户端
|
||||
APIConnectionError, # API连接错误
|
||||
RateLimitError, # 速率限制错误
|
||||
Timeout, # 超时错误
|
||||
AsyncAzureOpenAI, # Azure OpenAI异步客户端
|
||||
)
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
# 重试机制相关导入
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
retry, # 重试装饰器
|
||||
stop_after_attempt, # 最大重试次数
|
||||
wait_exponential, # 指数退避等待
|
||||
retry_if_exception_type, # 基于异常类型的重试条件
|
||||
)
|
||||
|
||||
# 数据验证和类型提示
|
||||
from pydantic import BaseModel, Field # 数据验证模型
|
||||
from typing import List, Dict, Callable, Any # 类型提示
|
||||
|
||||
# 本地模块导入
|
||||
from .base import BaseKVStorage # 键值存储基类
|
||||
from .utils import (
|
||||
compute_args_hash, # 计算参数哈希值
|
||||
wrap_embedding_func_with_attrs, # 包装嵌入函数的装饰器
|
||||
)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Callable, Any
|
||||
from .base import BaseKVStorage
|
||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
|
||||
# 禁用并行化以避免tokenizers的并行化导致的问题
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
@ -1,30 +1,43 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Union
|
||||
from collections import Counter, defaultdict
|
||||
import warnings
|
||||
# 标准库导入
|
||||
import asyncio # 异步IO支持
|
||||
import json # JSON数据处理
|
||||
import re # 正则表达式支持
|
||||
from typing import Union # 类型提示:联合类型
|
||||
from collections import (
|
||||
Counter, # 计数器集合类
|
||||
defaultdict # 带默认值的字典
|
||||
)
|
||||
import warnings # 警告控制
|
||||
|
||||
# 从本地utils模块导入工具函数
|
||||
from .utils import (
|
||||
logger,
|
||||
clean_str,
|
||||
compute_mdhash_id,
|
||||
decode_tokens_by_tiktoken,
|
||||
encode_string_by_tiktoken,
|
||||
is_float_regex,
|
||||
list_of_list_to_csv,
|
||||
pack_user_ass_to_openai_messages,
|
||||
split_string_by_multi_markers,
|
||||
truncate_list_by_token_size,
|
||||
process_combine_contexts,
|
||||
logger, # 日志记录器
|
||||
clean_str, # 字符串清理函数
|
||||
compute_mdhash_id, # 计算MD5哈希ID
|
||||
decode_tokens_by_tiktoken, # tiktoken解码函数
|
||||
encode_string_by_tiktoken, # tiktoken编码函数
|
||||
is_float_regex, # 浮点数检查函数
|
||||
list_of_list_to_csv, # 列表转CSV函数
|
||||
pack_user_ass_to_openai_messages, # OpenAI消息打包函数
|
||||
split_string_by_multi_markers, # 多标记字符串分割函数
|
||||
truncate_list_by_token_size, # 基于token大小截断列表
|
||||
process_combine_contexts, # 上下文合并处理函数
|
||||
)
|
||||
|
||||
# 从本地base模块导入基础类
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
TextChunkSchema,
|
||||
QueryParam,
|
||||
BaseGraphStorage, # 图存储基类
|
||||
BaseKVStorage, # 键值存储基类
|
||||
BaseVectorStorage, # 向量存储基类
|
||||
TextChunkSchema, # 文本块模式定义
|
||||
QueryParam, # 查询参数类
|
||||
)
|
||||
|
||||
# 从本地prompt模块导入提示相关常量
|
||||
from .prompt import (
|
||||
GRAPH_FIELD_SEP, # 图字段分隔符
|
||||
PROMPTS # 提示模板集合
|
||||
)
|
||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||
|
||||
|
||||
def chunking_by_token_size(
|
||||
|
@ -1,23 +1,34 @@
|
||||
import asyncio
|
||||
import html
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
# 标准库导入
|
||||
import asyncio # 异步IO支持
|
||||
import html # HTML实体编解码
|
||||
import os # 操作系统接口,用于文件和路径操作
|
||||
|
||||
from .utils import (
|
||||
logger,
|
||||
load_json,
|
||||
write_json,
|
||||
compute_mdhash_id,
|
||||
# 数据类和类型提示相关导入
|
||||
from dataclasses import dataclass # 数据类装饰器
|
||||
from typing import (
|
||||
Any, # 任意类型
|
||||
Union, # 联合类型
|
||||
cast # 类型转换函数
|
||||
)
|
||||
|
||||
# 第三方库导入
|
||||
import networkx as nx # 图数据处理库
|
||||
import numpy as np # 数值计算库
|
||||
from nano_vectordb import NanoVectorDB # 向量数据库
|
||||
|
||||
# 从本地utils模块导入工具函数
|
||||
from .utils import (
|
||||
logger, # 日志记录器
|
||||
load_json, # JSON文件加载函数
|
||||
write_json, # JSON文件写入函数
|
||||
compute_mdhash_id, # 计算MD5哈希ID的函数
|
||||
)
|
||||
|
||||
# 从本地base模块导入基础存储类
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
BaseGraphStorage, # 图存储基类
|
||||
BaseKVStorage, # 键值存储基类
|
||||
BaseVectorStorage, # 向量存储基类
|
||||
)
|
||||
|
||||
@dataclass
|
||||
|
@ -1,19 +1,29 @@
|
||||
import asyncio
|
||||
import html
|
||||
import io
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from hashlib import md5
|
||||
from typing import Any, Union, List
|
||||
import xml.etree.ElementTree as ET
|
||||
# 标准库导入 - 异步和IO操作
|
||||
import asyncio # 异步IO支持
|
||||
import html # HTML实体编解码
|
||||
import io # 内存IO操作
|
||||
import csv # CSV文件处理
|
||||
import json # JSON数据处理
|
||||
import logging # 日志记录
|
||||
import os # 操作系统接口
|
||||
import re # 正则表达式
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
# 标准库导入 - 数据结构和工具
|
||||
from dataclasses import dataclass # 数据类装饰器
|
||||
from functools import wraps # 装饰器工具
|
||||
from hashlib import md5 # MD5哈希算法
|
||||
from typing import ( # 类型提示
|
||||
Any, # 任意类型
|
||||
Union, # 联合类型
|
||||
List # 列表类型
|
||||
)
|
||||
|
||||
# XML处理
|
||||
import xml.etree.ElementTree as ET # XML解析和处理
|
||||
|
||||
# 第三方库导入
|
||||
import numpy as np # 数值计算库
|
||||
import tiktoken # OpenAI的分词器
|
||||
|
||||
# 全局编码器变量
|
||||
ENCODER = None
|
||||
|
Loading…
Reference in New Issue
Block a user