# 标准库导入 - 异步和IO操作 import asyncio # 异步IO支持 import html # HTML实体编解码 import io # 内存IO操作 import csv # CSV文件处理 import json # JSON数据处理 import logging # 日志记录 import os # 操作系统接口 import re # 正则表达式 # 标准库导入 - 数据结构和工具 from dataclasses import dataclass # 数据类装饰器 from functools import wraps # 装饰器工具 from hashlib import md5 # MD5哈希算法 from typing import ( # 类型提示 Any, # 任意类型 Union, # 联合类型 List # 列表类型 ) # XML处理 import xml.etree.ElementTree as ET # XML解析和处理 # 第三方库导入 import numpy as np # 数值计算库 import tiktoken # OpenAI的分词器 # 全局编码器变量 ENCODER = None # 创建一个名为"lightrag"的日志记录器 logger = logging.getLogger("lightrag") def set_logger(log_file: str): """设置日志记录器 Args: log_file: 日志文件路径 """ # 设置日志级别为DEBUG logger.setLevel(logging.DEBUG) # 创建文件处理器 file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) # 设置日志格式:时间 - 名称 - 级别 - 消息 formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) file_handler.setFormatter(formatter) # 如果logger没有处理器,则添加处理器 if not logger.handlers: logger.addHandler(file_handler) @dataclass class EmbeddingFunc: """嵌入函数的包装类 Attributes: embedding_dim: 嵌入向量的维度 max_token_size: 最大token数量 func: 实际的嵌入函数 """ embedding_dim: int max_token_size: int func: callable async def __call__(self, *args, **kwargs) -> np.ndarray: """使类实例可调用,直接调用内部的嵌入函数""" return await self.func(*args, **kwargs) def locate_json_string_body_from_string(content: str) -> Union[str, None]: """从字符串中定位JSON字符串主体 Args: content: 输入字符串 Returns: 找到的JSON字符串或None """ # 使用正则表达式查找{}包围的内容,DOTALL模式允许匹配跨行 maybe_json_str = re.search(r"{.*}", content, re.DOTALL) if maybe_json_str is not None: return maybe_json_str.group(0) else: return None def convert_response_to_json(response: str) -> dict: """将响应字符串转换为JSON对象 Args: response: 响应字符串 Returns: 解析后的JSON字典 Raises: AssertionError: 无法从响应中解析JSON JSONDecodeError: JSON解析失败 """ json_str = locate_json_string_body_from_string(response) assert json_str is not None, f"Unable to parse JSON from response: {response}" try: data = json.loads(json_str) return data except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON: {json_str}") raise e from None def compute_args_hash(*args): """计算参数的MD5哈希值 Args: *args: 任意数量的参数 Returns: 参数的MD5哈希值的十六进制字符串 """ return md5(str(args).encode()).hexdigest() def compute_mdhash_id(content, prefix: str = ""): """计算内容的MD5哈希ID Args: content: 要计算哈希的内容 prefix: 哈希值的前缀(默认为空) Returns: 带前缀的MD5哈希值 """ return prefix + md5(content.encode()).hexdigest() def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): """限制异步函数的最大并发调用次数的装饰器 Args: max_size: 最大并发数 waitting_time: 等待时间间隔(秒) Returns: 装饰器函数 """ def final_decro(func): """内部装饰器,不使用asyncio.Semaphore以避免使用nest-asyncio""" __current_size = 0 @wraps(func) async def wait_func(*args, **kwargs): nonlocal __current_size # 当当前并发数达到最大值时,等待 while __current_size >= max_size: await asyncio.sleep(waitting_time) __current_size += 1 result = await func(*args, **kwargs) __current_size -= 1 return result return wait_func return final_decro def wrap_embedding_func_with_attrs(**kwargs): """使用属性包装嵌入函数的装饰器 Args: **kwargs: 传递给EmbeddingFunc的关键字参数 Returns: 返回一个EmbeddingFunc实例的装饰器 """ def final_decro(func) -> EmbeddingFunc: new_func = EmbeddingFunc(**kwargs, func=func) return new_func return final_decro def load_json(file_name): """从文件加载JSON数据 Args: file_name: JSON文件路径 Returns: 加载的JSON对象,如果文件不存在则返回None """ if not os.path.exists(file_name): return None with open(file_name, encoding="utf-8") as f: return json.load(f) def write_json(json_obj, file_name): """将JSON对象写入文件 Args: json_obj: 要写入的JSON对象 file_name: 目标文件路径 Note: 使用indent=2进行格式化,ensure_ascii=False支持中文 """ with open(file_name, "w", encoding="utf-8") as f: json.dump(json_obj, f, indent=2, ensure_ascii=False) def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"): """使用tiktoken将字符串编码为tokens Args: content: 要编码的字符串 model_name: 使用的模型名称 Returns: 编码后的token列表 """ global ENCODER if ENCODER is None: ENCODER = tiktoken.encoding_for_model(model_name) tokens = ENCODER.encode(content) return tokens def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"): """使用tiktoken将tokens解码为字符串 Args: tokens: token列表 model_name: 使用的模型名称 Returns: 解码后的字符串 """ global ENCODER if ENCODER is None: ENCODER = tiktoken.encoding_for_model(model_name) content = ENCODER.decode(tokens) return content def pack_user_ass_to_openai_messages(*args: str): """将用户和助手的对话打包成OpenAI消息格式 Args: *args: 交替的用户和助手消息 Returns: OpenAI格式的消息列表,奇数位为用户消息,偶数位为助手消息 """ roles = ["user", "assistant"] return [ {"role": roles[i % 2], "content": content} for i, content in enumerate(args) ] def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: """使用多个标记分割字符串 Args: content: 要分割的字符串 markers: 分割标记列表 Returns: 分割后的字符串列表,去除空白 """ if not markers: return [content] results = re.split("|".join(re.escape(marker) for marker in markers), content) return [r.strip() for r in results if r.strip()] # Refer the utils functions of the official GraphRAG implementation: # https://github.com/microsoft/graphrag def clean_str(input: Any) -> str: """清理字符串中的HTML转义字符和控制字符 Args: input: 输入字符串或其他类型 Returns: 清理后的字符串,如果输入不是字符串则原样返回 """ if not isinstance(input, str): return input result = html.unescape(input.strip()) return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) def is_float_regex(value): """检查字符串是否为浮点数格式 Args: value: 要检查的值 Returns: 是否为浮点数格式的布尔值 """ return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): """根据token大小截断列表 Args: list_data: 要截断的列表 key: 从列表项中提取文本的函数 max_token_size: 最大token数量 Returns: 截断后的列表 """ if max_token_size <= 0: return [] tokens = 0 for i, data in enumerate(list_data): tokens += len(encode_string_by_tiktoken(key(data))) if tokens > max_token_size: return list_data[:i] return list_data def list_of_list_to_csv(data: List[List[str]]) -> str: """将二维列表转换为CSV字符串 Args: data: 二维字符串列表 Returns: CSV格式的字符串 """ output = io.StringIO() writer = csv.writer(output) writer.writerows(data) return output.getvalue() def csv_string_to_list(csv_string: str) -> List[List[str]]: """将CSV字符串转换为二维列表 Args: csv_string: CSV格式的字符串 Returns: 二维字符串列表 """ output = io.StringIO(csv_string) reader = csv.reader(output) return [row for row in reader] def save_data_to_file(data, file_name): """将数据保存为JSON文件 Args: data: 要保存的数据 file_name: 目标文件路径 """ with open(file_name, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=4) def xml_to_json(xml_file): """将GraphML格式的XML文件转换为JSON格式 Args: xml_file: GraphML文件路径 Returns: 包含节点和边信息的字典,解析失败时返回None Note: 转换后的数据结构为: { "nodes": [ { "id": str, "entity_type": str, "description": str, "source_id": str } ], "edges": [ { "source": str, "target": str, "weight": float, "description": str, "keywords": str, "source_id": str } ] } """ try: # 解析XML文件 tree = ET.parse(xml_file) root = tree.getroot() # 打印根元素信息以确认文件正确加载 print(f"Root element: {root.tag}") print(f"Root attributes: {root.attrib}") # 初始化数据结构 data = {"nodes": [], "edges": []} # 设置GraphML的命名空间 namespace = {"": "http://graphml.graphdrawing.org/xmlns"} # 处理所有节点 for node in root.findall(".//node", namespace): node_data = { # 获取节点ID并去除引号 "id": node.get("id").strip('"'), # 获取实体类型,如果不存在则为空字符串 "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "", # 获取描述信息 "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "", # 获取源ID "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else "", } data["nodes"].append(node_data) # 处理所有边 for edge in root.findall(".//edge", namespace): edge_data = { # 获取边的源节点和目标节点 "source": edge.get("source").strip('"'), "target": edge.get("target").strip('"'), # 获取权重,默认为0.0 "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0, # 获取描述信息 "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "", # 获取关键词 "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "", # 获取源ID "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else "", } data["edges"].append(edge_data) # 打印统计信息 print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") return data except ET.ParseError as e: print(f"Error parsing XML file: {e}") return None except Exception as e: print(f"An error occurred: {e}") return None def process_combine_contexts(hl, ll): """合并高层和低层上下文信息 Args: hl: 高层上下文的CSV字符串 ll: 低层上下文的CSV字符串 Returns: 合并后的CSV格式字符串 Note: 处理步骤: 1. 解析输入的CSV字符串 2. 提取并保留表头 3. 合并数据行并去重 4. 重新格式化为CSV字符串 """ # 初始化表头 header = None # 解析CSV字符串 list_hl = csv_string_to_list(hl.strip()) list_ll = csv_string_to_list(ll.strip()) # 提取表头 if list_hl: header = list_hl[0] list_hl = list_hl[1:] # 移除表头行 if list_ll: header = list_ll[0] list_ll = list_ll[1:] # 移除表头行 if header is None: return "" # 处理数据行,只保留除第一列外的数据 if list_hl: list_hl = [",".join(item[1:]) for item in list_hl if item] if list_ll: list_ll = [",".join(item[1:]) for item in list_ll if item] # 合并数据并去重 combined_sources_set = set(filter(None, list_hl + list_ll)) # 重新构建CSV字符串 combined_sources = [",\t".join(header)] # 添加表头 # 添加数据行,并加上新的序号 for i, item in enumerate(combined_sources_set, start=1): combined_sources.append(f"{i},\t{item}") # 用换行符连接所有行 combined_sources = "\n".join(combined_sources) return combined_sources