lightrag-comments/lightrag/utils.py

530 lines
15 KiB
Python
Raw Permalink Normal View History

2024-11-16 11:59:20 +08:00
# 标准库导入 - 异步和IO操作
import asyncio # 异步IO支持
import html # HTML实体编解码
import io # 内存IO操作
import csv # CSV文件处理
import json # JSON数据处理
import logging # 日志记录
import os # 操作系统接口
import re # 正则表达式
# 标准库导入 - 数据结构和工具
from dataclasses import dataclass # 数据类装饰器
from functools import wraps # 装饰器工具
from hashlib import md5 # MD5哈希算法
from typing import ( # 类型提示
Any, # 任意类型
Union, # 联合类型
List # 列表类型
)
# XML处理
import xml.etree.ElementTree as ET # XML解析和处理
# 第三方库导入
import numpy as np # 数值计算库
import tiktoken # OpenAI的分词器
2024-11-16 11:40:45 +08:00
# 全局编码器变量
ENCODER = None
2024-11-16 11:40:45 +08:00
# 创建一个名为"lightrag"的日志记录器
logger = logging.getLogger("lightrag")
def set_logger(log_file: str):
2024-11-16 11:40:45 +08:00
"""设置日志记录器
Args:
log_file: 日志文件路径
"""
# 设置日志级别为DEBUG
logger.setLevel(logging.DEBUG)
2024-11-16 11:40:45 +08:00
# 创建文件处理器
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
2024-11-16 11:40:45 +08:00
# 设置日志格式:时间 - 名称 - 级别 - 消息
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
2024-11-16 11:40:45 +08:00
# 如果logger没有处理器则添加处理器
if not logger.handlers:
logger.addHandler(file_handler)
@dataclass
class EmbeddingFunc:
2024-11-16 11:40:45 +08:00
"""嵌入函数的包装类
Attributes:
embedding_dim: 嵌入向量的维度
max_token_size: 最大token数量
func: 实际的嵌入函数
"""
embedding_dim: int
max_token_size: int
func: callable
async def __call__(self, *args, **kwargs) -> np.ndarray:
2024-11-16 11:40:45 +08:00
"""使类实例可调用,直接调用内部的嵌入函数"""
return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
2024-11-16 11:40:45 +08:00
"""从字符串中定位JSON字符串主体
Args:
content: 输入字符串
Returns:
找到的JSON字符串或None
"""
# 使用正则表达式查找{}包围的内容DOTALL模式允许匹配跨行
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None:
return maybe_json_str.group(0)
else:
return None
def convert_response_to_json(response: str) -> dict:
2024-11-16 11:40:45 +08:00
"""将响应字符串转换为JSON对象
Args:
response: 响应字符串
Returns:
解析后的JSON字典
Raises:
AssertionError: 无法从响应中解析JSON
JSONDecodeError: JSON解析失败
"""
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
try:
data = json.loads(json_str)
return data
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None
def compute_args_hash(*args):
2024-11-16 11:40:45 +08:00
"""计算参数的MD5哈希值
Args:
*args: 任意数量的参数
Returns:
参数的MD5哈希值的十六进制字符串
"""
return md5(str(args).encode()).hexdigest()
def compute_mdhash_id(content, prefix: str = ""):
2024-11-16 11:40:45 +08:00
"""计算内容的MD5哈希ID
Args:
content: 要计算哈希的内容
prefix: 哈希值的前缀默认为空
Returns:
带前缀的MD5哈希值
"""
return prefix + md5(content.encode()).hexdigest()
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
2024-11-16 11:40:45 +08:00
"""限制异步函数的最大并发调用次数的装饰器
Args:
max_size: 最大并发数
waitting_time: 等待时间间隔
Returns:
装饰器函数
"""
def final_decro(func):
2024-11-16 11:40:45 +08:00
"""内部装饰器不使用asyncio.Semaphore以避免使用nest-asyncio"""
__current_size = 0
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
2024-11-16 11:40:45 +08:00
# 当当前并发数达到最大值时,等待
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
result = await func(*args, **kwargs)
__current_size -= 1
return result
return wait_func
return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
2024-11-16 11:40:45 +08:00
"""使用属性包装嵌入函数的装饰器
Args:
**kwargs: 传递给EmbeddingFunc的关键字参数
Returns:
返回一个EmbeddingFunc实例的装饰器
"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)
return new_func
return final_decro
def load_json(file_name):
2024-11-16 11:40:45 +08:00
"""从文件加载JSON数据
Args:
file_name: JSON文件路径
Returns:
加载的JSON对象如果文件不存在则返回None
"""
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
def write_json(json_obj, file_name):
2024-11-16 11:40:45 +08:00
"""将JSON对象写入文件
Args:
json_obj: 要写入的JSON对象
file_name: 目标文件路径
Note:
使用indent=2进行格式化ensure_ascii=False支持中文
"""
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
2024-11-16 11:40:45 +08:00
"""使用tiktoken将字符串编码为tokens
Args:
content: 要编码的字符串
model_name: 使用的模型名称
Returns:
编码后的token列表
"""
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
2024-11-16 11:40:45 +08:00
"""使用tiktoken将tokens解码为字符串
Args:
tokens: token列表
model_name: 使用的模型名称
Returns:
解码后的字符串
"""
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
def pack_user_ass_to_openai_messages(*args: str):
2024-11-16 11:40:45 +08:00
"""将用户和助手的对话打包成OpenAI消息格式
Args:
*args: 交替的用户和助手消息
Returns:
OpenAI格式的消息列表奇数位为用户消息偶数位为助手消息
"""
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
2024-11-16 11:40:45 +08:00
"""使用多个标记分割字符串
Args:
content: 要分割的字符串
markers: 分割标记列表
Returns:
分割后的字符串列表去除空白
"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
2024-11-16 11:40:45 +08:00
"""清理字符串中的HTML转义字符和控制字符
Args:
input: 输入字符串或其他类型
Returns:
清理后的字符串如果输入不是字符串则原样返回
"""
if not isinstance(input, str):
return input
result = html.unescape(input.strip())
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value):
2024-11-16 11:40:45 +08:00
"""检查字符串是否为浮点数格式
Args:
value: 要检查的值
Returns:
是否为浮点数格式的布尔值
"""
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
2024-11-16 11:40:45 +08:00
"""根据token大小截断列表
Args:
list_data: 要截断的列表
key: 从列表项中提取文本的函数
max_token_size: 最大token数量
Returns:
截断后的列表
"""
if max_token_size <= 0:
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str:
2024-11-16 11:40:45 +08:00
"""将二维列表转换为CSV字符串
Args:
data: 二维字符串列表
Returns:
CSV格式的字符串
"""
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(data)
return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
2024-11-16 11:40:45 +08:00
"""将CSV字符串转换为二维列表
Args:
csv_string: CSV格式的字符串
Returns:
二维字符串列表
"""
output = io.StringIO(csv_string)
reader = csv.reader(output)
return [row for row in reader]
def save_data_to_file(data, file_name):
2024-11-16 11:40:45 +08:00
"""将数据保存为JSON文件
Args:
data: 要保存的数据
file_name: 目标文件路径
"""
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
2024-11-16 11:40:45 +08:00
"""将GraphML格式的XML文件转换为JSON格式
Args:
xml_file: GraphML文件路径
Returns:
包含节点和边信息的字典解析失败时返回None
Note:
转换后的数据结构为:
{
"nodes": [
{
"id": str,
"entity_type": str,
"description": str,
"source_id": str
}
],
"edges": [
{
"source": str,
"target": str,
"weight": float,
"description": str,
"keywords": str,
"source_id": str
}
]
}
"""
try:
2024-11-16 11:40:45 +08:00
# 解析XML文件
tree = ET.parse(xml_file)
root = tree.getroot()
2024-11-16 11:40:45 +08:00
# 打印根元素信息以确认文件正确加载
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
2024-11-16 11:40:45 +08:00
# 初始化数据结构
data = {"nodes": [], "edges": []}
2024-11-16 11:40:45 +08:00
# 设置GraphML的命名空间
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
2024-11-16 11:40:45 +08:00
# 处理所有节点
for node in root.findall(".//node", namespace):
node_data = {
2024-11-16 11:40:45 +08:00
# 获取节点ID并去除引号
"id": node.get("id").strip('"'),
2024-11-16 11:40:45 +08:00
# 获取实体类型,如果不存在则为空字符串
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
if node.find("./data[@key='d0']", namespace) is not None
else "",
2024-11-16 11:40:45 +08:00
# 获取描述信息
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
2024-11-16 11:40:45 +08:00
# 获取源ID
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
2024-11-16 11:40:45 +08:00
# 处理所有边
for edge in root.findall(".//edge", namespace):
edge_data = {
2024-11-16 11:40:45 +08:00
# 获取边的源节点和目标节点
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
2024-11-16 11:40:45 +08:00
# 获取权重默认为0.0
"weight": float(edge.find("./data[@key='d3']", namespace).text)
if edge.find("./data[@key='d3']", namespace) is not None
else 0.0,
2024-11-16 11:40:45 +08:00
# 获取描述信息
"description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
2024-11-16 11:40:45 +08:00
# 获取关键词
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
2024-11-16 11:40:45 +08:00
# 获取源ID
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
}
data["edges"].append(edge_data)
2024-11-16 11:40:45 +08:00
# 打印统计信息
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
return data
except ET.ParseError as e:
print(f"Error parsing XML file: {e}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
def process_combine_contexts(hl, ll):
2024-11-16 11:40:45 +08:00
"""合并高层和低层上下文信息
Args:
hl: 高层上下文的CSV字符串
ll: 低层上下文的CSV字符串
Returns:
合并后的CSV格式字符串
Note:
处理步骤
1. 解析输入的CSV字符串
2. 提取并保留表头
3. 合并数据行并去重
4. 重新格式化为CSV字符串
"""
# 初始化表头
header = None
2024-11-16 11:40:45 +08:00
# 解析CSV字符串
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
2024-11-16 11:40:45 +08:00
# 提取表头
if list_hl:
header = list_hl[0]
2024-11-16 11:40:45 +08:00
list_hl = list_hl[1:] # 移除表头行
if list_ll:
header = list_ll[0]
2024-11-16 11:40:45 +08:00
list_ll = list_ll[1:] # 移除表头行
if header is None:
return ""
2024-11-16 11:40:45 +08:00
# 处理数据行,只保留除第一列外的数据
if list_hl:
list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll:
list_ll = [",".join(item[1:]) for item in list_ll if item]
2024-11-16 11:40:45 +08:00
# 合并数据并去重
combined_sources_set = set(filter(None, list_hl + list_ll))
2024-11-16 11:40:45 +08:00
# 重新构建CSV字符串
combined_sources = [",\t".join(header)] # 添加表头
2024-11-16 11:40:45 +08:00
# 添加数据行,并加上新的序号
for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}")
2024-11-16 11:40:45 +08:00
# 用换行符连接所有行
combined_sources = "\n".join(combined_sources)
return combined_sources