lightrag-comments/lightrag/utils.py

520 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import html
import io
import csv
import json
import logging
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
# 全局编码器变量
ENCODER = None
# 创建一个名为"lightrag"的日志记录器
logger = logging.getLogger("lightrag")
def set_logger(log_file: str):
    """Configure the module-level "lightrag" logger to write DEBUG logs to a file.

    Args:
        log_file: Path of the log file to write to.
    """
    logger.setLevel(logging.DEBUG)

    handler = logging.FileHandler(log_file)
    handler.setLevel(logging.DEBUG)
    # Format: time - logger name - level - message
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )

    # Attach a handler only once so repeated calls do not duplicate output.
    if not logger.handlers:
        logger.addHandler(handler)
@dataclass
class EmbeddingFunc:
    """Wrapper bundling an async embedding callable with its metadata.

    Attributes:
        embedding_dim: Dimensionality of the produced embedding vectors.
        max_token_size: Maximum number of tokens the function accepts.
        func: The underlying asynchronous embedding callable.
    """

    embedding_dim: int
    max_token_size: int
    func: callable

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        """Delegate calls on the instance straight to the wrapped function."""
        return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
    """Extract the first {...}-delimited region from a string.

    Args:
        content: Arbitrary text that may embed a JSON object.

    Returns:
        The matched substring including the braces, or None when no
        brace-delimited region exists.
    """
    # DOTALL lets the body span multiple lines; the pattern is greedy, so
    # the match runs from the first '{' to the last '}'.
    match = re.search(r"{.*}", content, re.DOTALL)
    return match.group(0) if match is not None else None
def convert_response_to_json(response: str) -> dict:
    """Parse the JSON object embedded in a model response string.

    Args:
        response: Raw response text expected to contain a JSON object.

    Returns:
        The parsed JSON value.

    Raises:
        AssertionError: When no brace-delimited region is present.
        json.JSONDecodeError: When the located region is not valid JSON.
    """
    body = locate_json_string_body_from_string(response)
    assert body is not None, f"Unable to parse JSON from response: {response}"
    try:
        return json.loads(body)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {body}")
        raise e from None
def compute_args_hash(*args):
    """Return the hex MD5 digest of the repr of the argument tuple.

    Args:
        *args: Any number of positional arguments.

    Returns:
        Hexadecimal MD5 digest string identifying the argument combination.
    """
    serialized = str(args).encode()
    return md5(serialized).hexdigest()
def compute_mdhash_id(content, prefix: str = ""):
    """Build a deterministic ID from content via MD5, optionally prefixed.

    Args:
        content: String whose hash forms the ID.
        prefix: Optional string prepended to the digest (default "").

    Returns:
        The prefix followed by the hex MD5 digest of the content.
    """
    digest = md5(content.encode()).hexdigest()
    return f"{prefix}{digest}"
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Decorator limiting the number of concurrent calls of an async function.

    Implemented with a plain counter plus polling instead of
    asyncio.Semaphore so it works without nest-asyncio.

    Args:
        max_size: Maximum number of in-flight calls allowed at once.
        waitting_time: Polling interval in seconds while waiting for a slot.

    Returns:
        A decorator that wraps an async function with the concurrency cap.
    """

    def final_decro(func):
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            # Busy-wait (with sleeps) until a slot frees up.
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            try:
                return await func(*args, **kwargs)
            finally:
                # Release the slot even when func raises; without the
                # finally, a failing call permanently consumed a slot and
                # could eventually deadlock all future callers.
                __current_size -= 1

        return wait_func

    return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
    """Decorator factory that wraps an embedding function in EmbeddingFunc.

    Args:
        **kwargs: Attributes forwarded to EmbeddingFunc (embedding_dim,
            max_token_size).

    Returns:
        A decorator turning a raw async embedding function into an
        EmbeddingFunc instance.
    """

    def final_decro(func) -> EmbeddingFunc:
        return EmbeddingFunc(**kwargs, func=func)

    return final_decro
def load_json(file_name):
    """Load a JSON document from disk.

    Args:
        file_name: Path of the JSON file.

    Returns:
        The parsed JSON value, or None when the file does not exist.
    """
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as fp:
        return json.load(fp)
def write_json(json_obj, file_name):
    """Serialize a JSON-compatible object to a UTF-8 file.

    Args:
        json_obj: Object to serialize.
        file_name: Destination path.

    Note:
        Pretty-prints with indent=2; ensure_ascii=False keeps non-ASCII
        characters (e.g. CJK text) readable in the output.
    """
    with open(file_name, "w", encoding="utf-8") as fp:
        json.dump(json_obj, fp, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
    """Encode a string into tokens with a lazily built tiktoken encoder.

    Args:
        content: Text to tokenize.
        model_name: Model whose encoding builds the encoder on first use.

    Returns:
        The list of token ids.
    """
    global ENCODER
    # Build the shared encoder once. NOTE(review): a later call with a
    # different model_name still reuses the first encoder.
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    return ENCODER.encode(content)
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
    """Decode token ids back into a string with the shared tiktoken encoder.

    Args:
        tokens: Token id list.
        model_name: Model whose encoding builds the encoder on first use.

    Returns:
        The decoded string.
    """
    global ENCODER
    # Build the shared encoder once. NOTE(review): a later call with a
    # different model_name still reuses the first encoder.
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    return ENCODER.decode(tokens)
def pack_user_ass_to_openai_messages(*args: str):
    """Interleave message strings into OpenAI chat format.

    Args:
        *args: Messages alternating user, assistant, user, ...

    Returns:
        A list of {"role", "content"} dicts where even positions are "user"
        and odd positions are "assistant".
    """
    messages = []
    for index, content in enumerate(args):
        role = "user" if index % 2 == 0 else "assistant"
        messages.append({"role": role, "content": content})
    return messages
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string on any of several marker substrings.

    Args:
        content: Text to split.
        markers: Delimiter substrings; regex metacharacters are escaped.

    Returns:
        The non-empty pieces, each stripped of surrounding whitespace.
        With no markers, the content is returned unchanged as one item.
    """
    if not markers:
        return [content]
    pattern = "|".join(re.escape(marker) for marker in markers)
    pieces = re.split(pattern, content)
    return [piece.strip() for piece in pieces if piece.strip()]
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Unescape HTML entities and strip control characters from a string.

    Args:
        input: The value to clean; non-strings pass through untouched.

    Returns:
        The cleaned string, or the original value when it is not a str.
    """
    if not isinstance(input, str):
        return input
    unescaped = html.unescape(input.strip())
    # Drop ASCII control characters plus DEL and the C1 control range.
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
def is_float_regex(value):
    """Check whether a string looks like a (possibly signed) decimal number.

    Args:
        value: The string to test.

    Returns:
        True when the whole string is an optional sign, optional integer
        part, optional dot, and at least one trailing digit.
    """
    return re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value) is not None
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
    """Keep the leading items of a list whose token total fits a budget.

    Args:
        list_data: Items to truncate.
        key: Extracts the text to tokenize from each item.
        max_token_size: Token budget; non-positive budgets yield [].

    Returns:
        The longest prefix of list_data whose cumulative token count does
        not exceed max_token_size (the first overflowing item is excluded).
    """
    if max_token_size <= 0:
        return []
    total = 0
    for index, item in enumerate(list_data):
        total += len(encode_string_by_tiktoken(key(item)))
        if total > max_token_size:
            return list_data[:index]
    return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Render a table (list of rows) as a CSV string.

    Args:
        data: Rows of string cells.

    Returns:
        The CSV text, rows terminated with \\r\\n as csv.writer produces.
    """
    buffer = io.StringIO()
    csv.writer(buffer).writerows(data)
    return buffer.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse a CSV string back into rows of string cells.

    Args:
        csv_string: CSV-formatted text.

    Returns:
        A list of rows, each a list of cell strings.
    """
    return list(csv.reader(io.StringIO(csv_string)))
def save_data_to_file(data, file_name):
    """Persist a JSON-compatible object to a file, pretty-printed.

    Args:
        data: Object to serialize.
        file_name: Destination path.

    Note:
        Uses indent=4 and ensure_ascii=False (keeps non-ASCII readable).
    """
    with open(file_name, "w", encoding="utf-8") as fp:
        json.dump(data, fp, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
    """Convert a GraphML XML file into a dict of nodes and edges.

    Args:
        xml_file: Path to the GraphML file.

    Returns:
        On success, a dict of the shape:
            {
                "nodes": [{"id", "entity_type", "description", "source_id"}],
                "edges": [{"source", "target", "weight", "description",
                           "keywords", "source_id"}],
            }
        Returns None when parsing fails.
    """

    def get_data(element, key, default=""):
        # Single lookup per key (the original called find() twice per
        # field) and tolerant of an empty <data> element whose .text is
        # None (the original raised AttributeError on those, which the
        # broad except turned into a silent None for the whole file).
        found = element.find(f"./data[@key='{key}']", namespace)
        if found is None or found.text is None:
            return default
        return found.text

    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        # Confirm the file loaded correctly.
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}
        # GraphML default namespace (empty prefix matches unprefixed tags).
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            data["nodes"].append(
                {
                    # Stored ids are wrapped in literal quotes; strip them.
                    # Default "" guards against a node with no id attribute.
                    "id": node.get("id", "").strip('"'),
                    "entity_type": get_data(node, "d0").strip('"'),
                    "description": get_data(node, "d1"),
                    "source_id": get_data(node, "d2"),
                }
            )

        for edge in root.findall(".//edge", namespace):
            weight_text = get_data(edge, "d3", None)
            data["edges"].append(
                {
                    "source": edge.get("source", "").strip('"'),
                    "target": edge.get("target", "").strip('"'),
                    # Weight defaults to 0.0 when the d3 key is absent.
                    "weight": float(weight_text) if weight_text is not None else 0.0,
                    "description": get_data(edge, "d4"),
                    "keywords": get_data(edge, "d5"),
                    "source_id": get_data(edge, "d6"),
                }
            )

        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
def process_combine_contexts(hl, ll):
    """Merge high-level and low-level context CSV strings into one table.

    Args:
        hl: High-level context as a CSV string (header + data rows).
        ll: Low-level context as a CSV string (header + data rows).

    Returns:
        The merged string: the shared header joined with ",\\t", followed by
        renumbered data rows (each row's original first column is dropped).
        Returns "" when neither input provides a header.

    Steps:
        1. Parse both CSV strings.
        2. Take the header (ll's header wins when both are present).
        3. Drop each row's first column, merge rows, and de-duplicate.
        4. Re-number rows and rebuild the combined string.
    """
    header = None
    list_hl = csv_string_to_list(hl.strip())
    list_ll = csv_string_to_list(ll.strip())

    if list_hl:
        header = list_hl[0]
        list_hl = list_hl[1:]  # drop the header row
    if list_ll:
        header = list_ll[0]
        list_ll = list_ll[1:]  # drop the header row
    if header is None:
        return ""

    # Keep everything except the original first column (the old row number).
    if list_hl:
        list_hl = [",".join(row[1:]) for row in list_hl if row]
    if list_ll:
        list_ll = [",".join(row[1:]) for row in list_ll if row]

    # De-duplicate while preserving first-seen order. The original used a
    # set, which made the output row order nondeterministic across runs.
    combined_rows = dict.fromkeys(row for row in list_hl + list_ll if row)

    lines = [",\t".join(header)]
    for new_index, row in enumerate(combined_rows, start=1):
        lines.append(f"{new_index},\t{row}")
    return "\n".join(lines)