使用cursor加注释。
This commit is contained in:
parent
c8ee7286cb
commit
6969e4afc1
@ -5,119 +5,162 @@ import numpy as np
|
|||||||
|
|
||||||
from .utils import EmbeddingFunc
|
from .utils import EmbeddingFunc
|
||||||
|
|
||||||
|
# 定义文本块的数据结构,包含令牌数、内容、完整文档ID和块序号
|
||||||
TextChunkSchema = TypedDict(
|
TextChunkSchema = TypedDict(
|
||||||
"TextChunkSchema",
|
"TextChunkSchema",
|
||||||
{"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
|
{"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 定义泛型类型变量
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QueryParam:
|
class QueryParam:
|
||||||
|
"""查询参数配置类
|
||||||
|
|
||||||
|
属性:
|
||||||
|
mode: 查询模式,可选 'local'(局部)、'global'(全局)、'hybrid'(混合)或'naive'(朴素)
|
||||||
|
only_need_context: 是否仅需要上下文
|
||||||
|
response_type: 响应类型
|
||||||
|
top_k: 检索的top-k项数量
|
||||||
|
max_token_for_text_unit: 原始文本块的最大令牌数
|
||||||
|
max_token_for_global_context: 关系描述的最大令牌数
|
||||||
|
max_token_for_local_context: 实体描述的最大令牌数
|
||||||
|
"""
|
||||||
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
||||||
only_need_context: bool = False
|
only_need_context: bool = False
|
||||||
response_type: str = "Multiple Paragraphs"
|
response_type: str = "Multiple Paragraphs"
|
||||||
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
|
|
||||||
top_k: int = 60
|
top_k: int = 60
|
||||||
# Number of tokens for the original chunks.
|
|
||||||
max_token_for_text_unit: int = 4000
|
max_token_for_text_unit: int = 4000
|
||||||
# Number of tokens for the relationship descriptions
|
|
||||||
max_token_for_global_context: int = 4000
|
max_token_for_global_context: int = 4000
|
||||||
# Number of tokens for the entity descriptions
|
|
||||||
max_token_for_local_context: int = 4000
|
max_token_for_local_context: int = 4000
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StorageNameSpace:
|
class StorageNameSpace:
|
||||||
|
"""存储命名空间基类
|
||||||
|
|
||||||
|
属性:
|
||||||
|
namespace: 命名空间
|
||||||
|
global_config: 全局配置字典
|
||||||
|
"""
|
||||||
namespace: str
|
namespace: str
|
||||||
global_config: dict
|
global_config: dict
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
"""commit the storage operations after indexing"""
|
"""索引完成后的回调函数,用于提交存储操作"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def query_done_callback(self):
|
async def query_done_callback(self):
|
||||||
"""commit the storage operations after querying"""
|
"""查询完成后的回调函数,用于提交存储操作"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseVectorStorage(StorageNameSpace):
|
class BaseVectorStorage(StorageNameSpace):
|
||||||
|
"""向量存储基类
|
||||||
|
|
||||||
|
属性:
|
||||||
|
embedding_func: 嵌入函数
|
||||||
|
meta_fields: 元数据字段集合
|
||||||
|
"""
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
meta_fields: set = field(default_factory=set)
|
meta_fields: set = field(default_factory=set)
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||||
|
"""查询接口"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict]):
|
||||||
"""Use 'content' field from value for embedding, use key as id.
|
"""更新或插入数据
|
||||||
If embedding_func is None, use 'embedding' field from value
|
使用value中的'content'字段进行嵌入,使用key作为ID
|
||||||
|
如果embedding_func为None,则使用value中的'embedding'字段
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseKVStorage(Generic[T], StorageNameSpace):
|
class BaseKVStorage(Generic[T], StorageNameSpace):
|
||||||
|
"""键值存储基类"""
|
||||||
|
|
||||||
async def all_keys(self) -> list[str]:
|
async def all_keys(self) -> list[str]:
|
||||||
|
"""获取所有键"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[T, None]:
|
async def get_by_id(self, id: str) -> Union[T, None]:
|
||||||
|
"""通过ID获取值"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_by_ids(
|
async def get_by_ids(
|
||||||
self, ids: list[str], fields: Union[set[str], None] = None
|
self, ids: list[str], fields: Union[set[str], None] = None
|
||||||
) -> list[Union[T, None]]:
|
) -> list[Union[T, None]]:
|
||||||
|
"""通过ID列表批量获取值"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||||
"""return un-exist keys"""
|
"""返回不存在的键集合"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, T]):
|
async def upsert(self, data: dict[str, T]):
|
||||||
|
"""更新或插入数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def drop(self):
|
async def drop(self):
|
||||||
|
"""删除存储"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseGraphStorage(StorageNameSpace):
|
class BaseGraphStorage(StorageNameSpace):
|
||||||
|
"""图存储基类"""
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
|
"""检查节点是否存在"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
|
"""检查边是否存在"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
|
"""获取节点的度"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
|
"""获取边的度"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
|
"""获取节点信息"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
|
"""获取边信息"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_node_edges(
|
async def get_node_edges(
|
||||||
self, source_node_id: str
|
self, source_node_id: str
|
||||||
) -> Union[list[tuple[str, str]], None]:
|
) -> Union[list[tuple[str, str]], None]:
|
||||||
|
"""获取节点的所有边"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||||
|
"""更新或插入节点"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
):
|
||||||
|
"""更新或插入边"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
async def delete_node(self, node_id: str):
|
||||||
|
"""删除节点"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||||
|
"""节点嵌入(在lightrag中未使用)"""
|
||||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||||
|
@ -15,38 +15,65 @@ import xml.etree.ElementTree as ET
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
# 全局编码器变量
|
||||||
ENCODER = None
|
ENCODER = None
|
||||||
|
|
||||||
|
# 创建一个名为"lightrag"的日志记录器
|
||||||
logger = logging.getLogger("lightrag")
|
logger = logging.getLogger("lightrag")
|
||||||
|
|
||||||
|
|
||||||
def set_logger(log_file: str):
|
def set_logger(log_file: str):
|
||||||
|
"""设置日志记录器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_file: 日志文件路径
|
||||||
|
"""
|
||||||
|
# 设置日志级别为DEBUG
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
# 创建文件处理器
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
# 设置日志格式:时间 - 名称 - 级别 - 消息
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
)
|
)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# 如果logger没有处理器,则添加处理器
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
|
"""嵌入函数的包装类
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
embedding_dim: 嵌入向量的维度
|
||||||
|
max_token_size: 最大token数量
|
||||||
|
func: 实际的嵌入函数
|
||||||
|
"""
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
max_token_size: int
|
max_token_size: int
|
||||||
func: callable
|
func: callable
|
||||||
|
|
||||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||||
|
"""使类实例可调用,直接调用内部的嵌入函数"""
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||||
"""Locate the JSON string body from a string"""
|
"""从字符串中定位JSON字符串主体
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 输入字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
找到的JSON字符串或None
|
||||||
|
"""
|
||||||
|
# 使用正则表达式查找{}包围的内容,DOTALL模式允许匹配跨行
|
||||||
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
||||||
if maybe_json_str is not None:
|
if maybe_json_str is not None:
|
||||||
return maybe_json_str.group(0)
|
return maybe_json_str.group(0)
|
||||||
@ -55,6 +82,18 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
|||||||
|
|
||||||
|
|
||||||
def convert_response_to_json(response: str) -> dict:
|
def convert_response_to_json(response: str) -> dict:
|
||||||
|
"""将响应字符串转换为JSON对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: 响应字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的JSON字典
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: 无法从响应中解析JSON
|
||||||
|
JSONDecodeError: JSON解析失败
|
||||||
|
"""
|
||||||
json_str = locate_json_string_body_from_string(response)
|
json_str = locate_json_string_body_from_string(response)
|
||||||
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
||||||
try:
|
try:
|
||||||
@ -66,23 +105,48 @@ def convert_response_to_json(response: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args):
|
def compute_args_hash(*args):
|
||||||
|
"""计算参数的MD5哈希值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: 任意数量的参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
参数的MD5哈希值的十六进制字符串
|
||||||
|
"""
|
||||||
return md5(str(args).encode()).hexdigest()
|
return md5(str(args).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def compute_mdhash_id(content, prefix: str = ""):
|
def compute_mdhash_id(content, prefix: str = ""):
|
||||||
|
"""计算内容的MD5哈希ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 要计算哈希的内容
|
||||||
|
prefix: 哈希值的前缀(默认为空)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
带前缀的MD5哈希值
|
||||||
|
"""
|
||||||
return prefix + md5(content.encode()).hexdigest()
|
return prefix + md5(content.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
||||||
"""Add restriction of maximum async calling times for a async func"""
|
"""限制异步函数的最大并发调用次数的装饰器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: 最大并发数
|
||||||
|
waitting_time: 等待时间间隔(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
装饰器函数
|
||||||
|
"""
|
||||||
def final_decro(func):
|
def final_decro(func):
|
||||||
"""Not using async.Semaphore to aovid use nest-asyncio"""
|
"""内部装饰器,不使用asyncio.Semaphore以避免使用nest-asyncio"""
|
||||||
__current_size = 0
|
__current_size = 0
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wait_func(*args, **kwargs):
|
async def wait_func(*args, **kwargs):
|
||||||
nonlocal __current_size
|
nonlocal __current_size
|
||||||
|
# 当当前并发数达到最大值时,等待
|
||||||
while __current_size >= max_size:
|
while __current_size >= max_size:
|
||||||
await asyncio.sleep(waitting_time)
|
await asyncio.sleep(waitting_time)
|
||||||
__current_size += 1
|
__current_size += 1
|
||||||
@ -96,8 +160,14 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_embedding_func_with_attrs(**kwargs):
|
def wrap_embedding_func_with_attrs(**kwargs):
|
||||||
"""Wrap a function with attributes"""
|
"""使用属性包装嵌入函数的装饰器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: 传递给EmbeddingFunc的关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
返回一个EmbeddingFunc实例的装饰器
|
||||||
|
"""
|
||||||
def final_decro(func) -> EmbeddingFunc:
|
def final_decro(func) -> EmbeddingFunc:
|
||||||
new_func = EmbeddingFunc(**kwargs, func=func)
|
new_func = EmbeddingFunc(**kwargs, func=func)
|
||||||
return new_func
|
return new_func
|
||||||
@ -106,6 +176,14 @@ def wrap_embedding_func_with_attrs(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def load_json(file_name):
|
def load_json(file_name):
|
||||||
|
"""从文件加载JSON数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: JSON文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加载的JSON对象,如果文件不存在则返回None
|
||||||
|
"""
|
||||||
if not os.path.exists(file_name):
|
if not os.path.exists(file_name):
|
||||||
return None
|
return None
|
||||||
with open(file_name, encoding="utf-8") as f:
|
with open(file_name, encoding="utf-8") as f:
|
||||||
@ -113,11 +191,29 @@ def load_json(file_name):
|
|||||||
|
|
||||||
|
|
||||||
def write_json(json_obj, file_name):
|
def write_json(json_obj, file_name):
|
||||||
|
"""将JSON对象写入文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_obj: 要写入的JSON对象
|
||||||
|
file_name: 目标文件路径
|
||||||
|
|
||||||
|
Note:
|
||||||
|
使用indent=2进行格式化,ensure_ascii=False支持中文
|
||||||
|
"""
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
||||||
|
"""使用tiktoken将字符串编码为tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 要编码的字符串
|
||||||
|
model_name: 使用的模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
编码后的token列表
|
||||||
|
"""
|
||||||
global ENCODER
|
global ENCODER
|
||||||
if ENCODER is None:
|
if ENCODER is None:
|
||||||
ENCODER = tiktoken.encoding_for_model(model_name)
|
ENCODER = tiktoken.encoding_for_model(model_name)
|
||||||
@ -126,6 +222,15 @@ def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
|||||||
|
|
||||||
|
|
||||||
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
||||||
|
"""使用tiktoken将tokens解码为字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: token列表
|
||||||
|
model_name: 使用的模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解码后的字符串
|
||||||
|
"""
|
||||||
global ENCODER
|
global ENCODER
|
||||||
if ENCODER is None:
|
if ENCODER is None:
|
||||||
ENCODER = tiktoken.encoding_for_model(model_name)
|
ENCODER = tiktoken.encoding_for_model(model_name)
|
||||||
@ -134,6 +239,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
|||||||
|
|
||||||
|
|
||||||
def pack_user_ass_to_openai_messages(*args: str):
|
def pack_user_ass_to_openai_messages(*args: str):
|
||||||
|
"""将用户和助手的对话打包成OpenAI消息格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: 交替的用户和助手消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAI格式的消息列表,奇数位为用户消息,偶数位为助手消息
|
||||||
|
"""
|
||||||
roles = ["user", "assistant"]
|
roles = ["user", "assistant"]
|
||||||
return [
|
return [
|
||||||
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
||||||
@ -141,7 +254,15 @@ def pack_user_ass_to_openai_messages(*args: str):
|
|||||||
|
|
||||||
|
|
||||||
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
||||||
"""Split a string by multiple markers"""
|
"""使用多个标记分割字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 要分割的字符串
|
||||||
|
markers: 分割标记列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分割后的字符串列表,去除空白
|
||||||
|
"""
|
||||||
if not markers:
|
if not markers:
|
||||||
return [content]
|
return [content]
|
||||||
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
||||||
@ -151,22 +272,44 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
|
|||||||
# Refer the utils functions of the official GraphRAG implementation:
|
# Refer the utils functions of the official GraphRAG implementation:
|
||||||
# https://github.com/microsoft/graphrag
|
# https://github.com/microsoft/graphrag
|
||||||
def clean_str(input: Any) -> str:
|
def clean_str(input: Any) -> str:
|
||||||
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
|
"""清理字符串中的HTML转义字符和控制字符
|
||||||
# If we get non-string input, just give it back
|
|
||||||
|
Args:
|
||||||
|
input: 输入字符串或其他类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的字符串,如果输入不是字符串则原样返回
|
||||||
|
"""
|
||||||
if not isinstance(input, str):
|
if not isinstance(input, str):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
result = html.unescape(input.strip())
|
result = html.unescape(input.strip())
|
||||||
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
|
||||||
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
||||||
|
|
||||||
|
|
||||||
def is_float_regex(value):
|
def is_float_regex(value):
|
||||||
|
"""检查字符串是否为浮点数格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 要检查的值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否为浮点数格式的布尔值
|
||||||
|
"""
|
||||||
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
||||||
|
|
||||||
|
|
||||||
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
||||||
"""Truncate a list of data by token size"""
|
"""根据token大小截断列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
list_data: 要截断的列表
|
||||||
|
key: 从列表项中提取文本的函数
|
||||||
|
max_token_size: 最大token数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
截断后的列表
|
||||||
|
"""
|
||||||
if max_token_size <= 0:
|
if max_token_size <= 0:
|
||||||
return []
|
return []
|
||||||
tokens = 0
|
tokens = 0
|
||||||
@ -178,6 +321,14 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
|
|||||||
|
|
||||||
|
|
||||||
def list_of_list_to_csv(data: List[List[str]]) -> str:
|
def list_of_list_to_csv(data: List[List[str]]) -> str:
|
||||||
|
"""将二维列表转换为CSV字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 二维字符串列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSV格式的字符串
|
||||||
|
"""
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
writer = csv.writer(output)
|
writer = csv.writer(output)
|
||||||
writer.writerows(data)
|
writer.writerows(data)
|
||||||
@ -185,65 +336,123 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def csv_string_to_list(csv_string: str) -> List[List[str]]:
|
def csv_string_to_list(csv_string: str) -> List[List[str]]:
|
||||||
|
"""将CSV字符串转换为二维列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
csv_string: CSV格式的字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
二维字符串列表
|
||||||
|
"""
|
||||||
output = io.StringIO(csv_string)
|
output = io.StringIO(csv_string)
|
||||||
reader = csv.reader(output)
|
reader = csv.reader(output)
|
||||||
return [row for row in reader]
|
return [row for row in reader]
|
||||||
|
|
||||||
|
|
||||||
def save_data_to_file(data, file_name):
|
def save_data_to_file(data, file_name):
|
||||||
|
"""将数据保存为JSON文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 要保存的数据
|
||||||
|
file_name: 目标文件路径
|
||||||
|
"""
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
|
||||||
def xml_to_json(xml_file):
|
def xml_to_json(xml_file):
|
||||||
|
"""将GraphML格式的XML文件转换为JSON格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xml_file: GraphML文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含节点和边信息的字典,解析失败时返回None
|
||||||
|
|
||||||
|
Note:
|
||||||
|
转换后的数据结构为:
|
||||||
|
{
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": str,
|
||||||
|
"entity_type": str,
|
||||||
|
"description": str,
|
||||||
|
"source_id": str
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"source": str,
|
||||||
|
"target": str,
|
||||||
|
"weight": float,
|
||||||
|
"description": str,
|
||||||
|
"keywords": str,
|
||||||
|
"source_id": str
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 解析XML文件
|
||||||
tree = ET.parse(xml_file)
|
tree = ET.parse(xml_file)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
# Print the root element's tag and attributes to confirm the file has been correctly loaded
|
# 打印根元素信息以确认文件正确加载
|
||||||
print(f"Root element: {root.tag}")
|
print(f"Root element: {root.tag}")
|
||||||
print(f"Root attributes: {root.attrib}")
|
print(f"Root attributes: {root.attrib}")
|
||||||
|
|
||||||
|
# 初始化数据结构
|
||||||
data = {"nodes": [], "edges": []}
|
data = {"nodes": [], "edges": []}
|
||||||
|
|
||||||
# Use namespace
|
# 设置GraphML的命名空间
|
||||||
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
|
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
|
||||||
|
|
||||||
|
# 处理所有节点
|
||||||
for node in root.findall(".//node", namespace):
|
for node in root.findall(".//node", namespace):
|
||||||
node_data = {
|
node_data = {
|
||||||
|
# 获取节点ID并去除引号
|
||||||
"id": node.get("id").strip('"'),
|
"id": node.get("id").strip('"'),
|
||||||
|
# 获取实体类型,如果不存在则为空字符串
|
||||||
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
|
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
|
||||||
if node.find("./data[@key='d0']", namespace) is not None
|
if node.find("./data[@key='d0']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
|
# 获取描述信息
|
||||||
"description": node.find("./data[@key='d1']", namespace).text
|
"description": node.find("./data[@key='d1']", namespace).text
|
||||||
if node.find("./data[@key='d1']", namespace) is not None
|
if node.find("./data[@key='d1']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
|
# 获取源ID
|
||||||
"source_id": node.find("./data[@key='d2']", namespace).text
|
"source_id": node.find("./data[@key='d2']", namespace).text
|
||||||
if node.find("./data[@key='d2']", namespace) is not None
|
if node.find("./data[@key='d2']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
}
|
}
|
||||||
data["nodes"].append(node_data)
|
data["nodes"].append(node_data)
|
||||||
|
|
||||||
|
# 处理所有边
|
||||||
for edge in root.findall(".//edge", namespace):
|
for edge in root.findall(".//edge", namespace):
|
||||||
edge_data = {
|
edge_data = {
|
||||||
|
# 获取边的源节点和目标节点
|
||||||
"source": edge.get("source").strip('"'),
|
"source": edge.get("source").strip('"'),
|
||||||
"target": edge.get("target").strip('"'),
|
"target": edge.get("target").strip('"'),
|
||||||
|
# 获取权重,默认为0.0
|
||||||
"weight": float(edge.find("./data[@key='d3']", namespace).text)
|
"weight": float(edge.find("./data[@key='d3']", namespace).text)
|
||||||
if edge.find("./data[@key='d3']", namespace) is not None
|
if edge.find("./data[@key='d3']", namespace) is not None
|
||||||
else 0.0,
|
else 0.0,
|
||||||
|
# 获取描述信息
|
||||||
"description": edge.find("./data[@key='d4']", namespace).text
|
"description": edge.find("./data[@key='d4']", namespace).text
|
||||||
if edge.find("./data[@key='d4']", namespace) is not None
|
if edge.find("./data[@key='d4']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
|
# 获取关键词
|
||||||
"keywords": edge.find("./data[@key='d5']", namespace).text
|
"keywords": edge.find("./data[@key='d5']", namespace).text
|
||||||
if edge.find("./data[@key='d5']", namespace) is not None
|
if edge.find("./data[@key='d5']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
|
# 获取源ID
|
||||||
"source_id": edge.find("./data[@key='d6']", namespace).text
|
"source_id": edge.find("./data[@key='d6']", namespace).text
|
||||||
if edge.find("./data[@key='d6']", namespace) is not None
|
if edge.find("./data[@key='d6']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
}
|
}
|
||||||
data["edges"].append(edge_data)
|
data["edges"].append(edge_data)
|
||||||
|
|
||||||
# Print the number of nodes and edges found
|
# 打印统计信息
|
||||||
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
|
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
@ -256,31 +465,55 @@ def xml_to_json(xml_file):
|
|||||||
|
|
||||||
|
|
||||||
def process_combine_contexts(hl, ll):
|
def process_combine_contexts(hl, ll):
|
||||||
|
"""合并高层和低层上下文信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hl: 高层上下文的CSV字符串
|
||||||
|
ll: 低层上下文的CSV字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合并后的CSV格式字符串
|
||||||
|
|
||||||
|
Note:
|
||||||
|
处理步骤:
|
||||||
|
1. 解析输入的CSV字符串
|
||||||
|
2. 提取并保留表头
|
||||||
|
3. 合并数据行并去重
|
||||||
|
4. 重新格式化为CSV字符串
|
||||||
|
"""
|
||||||
|
# 初始化表头
|
||||||
header = None
|
header = None
|
||||||
|
# 解析CSV字符串
|
||||||
list_hl = csv_string_to_list(hl.strip())
|
list_hl = csv_string_to_list(hl.strip())
|
||||||
list_ll = csv_string_to_list(ll.strip())
|
list_ll = csv_string_to_list(ll.strip())
|
||||||
|
|
||||||
|
# 提取表头
|
||||||
if list_hl:
|
if list_hl:
|
||||||
header = list_hl[0]
|
header = list_hl[0]
|
||||||
list_hl = list_hl[1:]
|
list_hl = list_hl[1:] # 移除表头行
|
||||||
if list_ll:
|
if list_ll:
|
||||||
header = list_ll[0]
|
header = list_ll[0]
|
||||||
list_ll = list_ll[1:]
|
list_ll = list_ll[1:] # 移除表头行
|
||||||
if header is None:
|
if header is None:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# 处理数据行,只保留除第一列外的数据
|
||||||
if list_hl:
|
if list_hl:
|
||||||
list_hl = [",".join(item[1:]) for item in list_hl if item]
|
list_hl = [",".join(item[1:]) for item in list_hl if item]
|
||||||
if list_ll:
|
if list_ll:
|
||||||
list_ll = [",".join(item[1:]) for item in list_ll if item]
|
list_ll = [",".join(item[1:]) for item in list_ll if item]
|
||||||
|
|
||||||
|
# 合并数据并去重
|
||||||
combined_sources_set = set(filter(None, list_hl + list_ll))
|
combined_sources_set = set(filter(None, list_hl + list_ll))
|
||||||
|
|
||||||
combined_sources = [",\t".join(header)]
|
# 重新构建CSV字符串
|
||||||
|
combined_sources = [",\t".join(header)] # 添加表头
|
||||||
|
|
||||||
|
# 添加数据行,并加上新的序号
|
||||||
for i, item in enumerate(combined_sources_set, start=1):
|
for i, item in enumerate(combined_sources_set, start=1):
|
||||||
combined_sources.append(f"{i},\t{item}")
|
combined_sources.append(f"{i},\t{item}")
|
||||||
|
|
||||||
|
# 用换行符连接所有行
|
||||||
combined_sources = "\n".join(combined_sources)
|
combined_sources = "\n".join(combined_sources)
|
||||||
|
|
||||||
return combined_sources
|
return combined_sources
|
||||||
|
Loading…
Reference in New Issue
Block a user