Add comments using Cursor.

many2many 2024-11-16 11:40:45 +08:00
parent c8ee7286cb
commit 6969e4afc1
2 changed files with 302 additions and 26 deletions

lightrag/base.py - View File

@@ -5,119 +5,162 @@ import numpy as np
from .utils import EmbeddingFunc
# Data structure for a text chunk: token count, content, full document ID, and chunk order index
TextChunkSchema = TypedDict(
"TextChunkSchema",
{"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)
# Generic type variable
T = TypeVar("T")
@dataclass
class QueryParam:
"""查询参数配置类
属性:
mode: 查询模式可选 'local'(局部)'global'(全局)'hybrid'(混合)'naive'(朴素)
only_need_context: 是否仅需要上下文
response_type: 响应类型
top_k: 检索的top-k项数量
max_token_for_text_unit: 原始文本块的最大令牌数
max_token_for_global_context: 关系描述的最大令牌数
max_token_for_local_context: 实体描述的最大令牌数
"""
mode: Literal["local", "global", "hybrid", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 60
# Number of tokens for the original chunks.
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
max_token_for_global_context: int = 4000
# Number of tokens for the entity descriptions
max_token_for_local_context: int = 4000
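# Illustrative usage sketch (not part of this commit): QueryParam is a plain
# dataclass, so callers override only the fields they need.
def _example_query_param() -> QueryParam:
    param = QueryParam(mode="hybrid", top_k=20, only_need_context=True)
    # Untouched fields keep their defaults, e.g. max_token_for_text_unit == 4000.
    return param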
@dataclass
class StorageNameSpace:
"""存储命名空间基类
属性:
namespace: 命名空间
global_config: 全局配置字典
"""
namespace: str
global_config: dict
async def index_done_callback(self):
"""commit the storage operations after indexing"""
"""索引完成后的回调函数,用于提交存储操作"""
pass
async def query_done_callback(self):
"""commit the storage operations after querying"""
"""查询完成后的回调函数,用于提交存储操作"""
pass
@dataclass
class BaseVectorStorage(StorageNameSpace):
"""向量存储基类
属性:
embedding_func: 嵌入函数
meta_fields: 元数据字段集合
"""
embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict]:
"""查询接口"""
raise NotImplementedError
async def upsert(self, data: dict[str, dict]):
"""Use 'content' field from value for embedding, use key as id.
If embedding_func is None, use 'embedding' field from value
"""更新或插入数据
使用value中的'content'字段进行嵌入使用key作为ID
如果embedding_func为None则使用value中的'embedding'字段
"""
raise NotImplementedError
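# Minimal in-memory subclass sketch (an illustrative assumption, not code from
# this repository): embeds on upsert and ranks queries by cosine similarity.
@dataclass
class _InMemoryVectorStorage(BaseVectorStorage):
    _data: dict = field(default_factory=dict)

    async def upsert(self, data: dict[str, dict]):
        for key, value in data.items():
            emb = await self.embedding_func([value["content"]])
            self._data[key] = {**value, "embedding": emb[0]}

    async def query(self, query: str, top_k: int) -> list[dict]:
        q = (await self.embedding_func([query]))[0]

        def _cos(v):
            return float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v)))

        ranked = sorted(self._data.items(), key=lambda kv: _cos(kv[1]["embedding"]), reverse=True)
        return [{"id": k, **v} for k, v in ranked[:top_k]]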
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
"""键值存储基类"""
async def all_keys(self) -> list[str]:
"""获取所有键"""
raise NotImplementedError
async def get_by_id(self, id: str) -> Union[T, None]:
"""通过ID获取值"""
raise NotImplementedError
async def get_by_ids(
self, ids: list[str], fields: Union[set[str], None] = None
) -> list[Union[T, None]]:
"""通过ID列表批量获取值"""
raise NotImplementedError
async def filter_keys(self, data: list[str]) -> set[str]:
"""return un-exist keys"""
"""返回不存在的键集合"""
raise NotImplementedError
async def upsert(self, data: dict[str, T]):
"""更新或插入数据"""
raise NotImplementedError
async def drop(self):
"""删除存储"""
raise NotImplementedError
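# Minimal dict-backed sketch (an illustrative assumption, not code from this
# repository) showing how the abstract key-value methods fit together.
@dataclass
class _InMemoryKV(BaseKVStorage[dict]):
    _data: dict = field(default_factory=dict)

    async def all_keys(self) -> list[str]:
        return list(self._data)

    async def get_by_id(self, id: str) -> Union[dict, None]:
        return self._data.get(id)

    async def get_by_ids(self, ids: list[str], fields=None):
        return [self._data.get(i) for i in ids]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return {k for k in data if k not in self._data}  # keys not yet stored

    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)

    async def drop(self):
        self._data.clear()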
@dataclass
class BaseGraphStorage(StorageNameSpace):
"""图存储基类"""
async def has_node(self, node_id: str) -> bool:
"""检查节点是否存在"""
raise NotImplementedError
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""检查边是否存在"""
raise NotImplementedError
async def node_degree(self, node_id: str) -> int:
"""获取节点的度"""
raise NotImplementedError
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""获取边的度"""
raise NotImplementedError
async def get_node(self, node_id: str) -> Union[dict, None]:
"""获取节点信息"""
raise NotImplementedError
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""获取边信息"""
raise NotImplementedError
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
"""获取节点的所有边"""
raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""更新或插入节点"""
raise NotImplementedError
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
"""更新或插入边"""
raise NotImplementedError
async def delete_node(self, node_id: str):
"""删除节点"""
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""节点嵌入在lightrag中未使用"""
raise NotImplementedError("Node embedding is not used in lightrag.")

lightrag/utils.py - View File

@@ -15,38 +15,65 @@ import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
# Global encoder variable
ENCODER = None
# Create a logger named "lightrag"
logger = logging.getLogger("lightrag")
def set_logger(log_file: str):
"""设置日志记录器
Args:
log_file: 日志文件路径
"""
# Set the log level to DEBUG
logger.setLevel(logging.DEBUG)
# Create a file handler
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
# Log format: time - name - level - message
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
# Add the handler only if the logger has none yet
if not logger.handlers:
logger.addHandler(file_handler)
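# Usage sketch (illustrative, not part of this commit): send lightrag's DEBUG
# output to a file, then log through the module-level logger.
def _example_logging(log_path: str = "lightrag.log"):
    set_logger(log_path)
    logger.debug("indexing started")  # written to log_path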
@dataclass
class EmbeddingFunc:
"""嵌入函数的包装类
Attributes:
embedding_dim: 嵌入向量的维度
max_token_size: 最大token数量
func: 实际的嵌入函数
"""
embedding_dim: int
max_token_size: int
func: callable
async def __call__(self, *args, **kwargs) -> np.ndarray:
"""使类实例可调用,直接调用内部的嵌入函数"""
return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
"""从字符串中定位JSON字符串主体
Args:
content: 输入字符串
Returns:
找到的JSON字符串或None
"""
# Find content enclosed in {}; DOTALL lets the match span multiple lines
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None:
return maybe_json_str.group(0)
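# Illustrative example (not part of this commit): the pattern is greedy, so it
# captures from the first '{' to the last '}', tolerating prose around the JSON.
def _example_locate_json():
    raw = 'Sure, here it is: {"keywords": ["graph", "rag"]} Anything else?'
    assert locate_json_string_body_from_string(raw) == '{"keywords": ["graph", "rag"]}'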
@@ -55,6 +82,18 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
def convert_response_to_json(response: str) -> dict:
"""将响应字符串转换为JSON对象
Args:
response: 响应字符串
Returns:
解析后的JSON字典
Raises:
AssertionError: 无法从响应中解析JSON
JSONDecodeError: JSON解析失败
"""
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
try:
@@ -66,23 +105,48 @@ def convert_response_to_json(response: str) -> dict:
def compute_args_hash(*args):
"""计算参数的MD5哈希值
Args:
*args: 任意数量的参数
Returns:
参数的MD5哈希值的十六进制字符串
"""
return md5(str(args).encode()).hexdigest()
def compute_mdhash_id(content, prefix: str = ""):
"""计算内容的MD5哈希ID
Args:
content: 要计算哈希的内容
prefix: 哈希值的前缀默认为空
Returns:
带前缀的MD5哈希值
"""
return prefix + md5(content.encode()).hexdigest()
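# Illustrative example (not part of this commit): identical content always maps
# to the same prefixed ID, which keeps repeated upserts idempotent. The
# "chunk-" prefix here is just an example value.
def _example_ids():
    a = compute_mdhash_id("some chunk text", prefix="chunk-")
    b = compute_mdhash_id("some chunk text", prefix="chunk-")
    assert a == b and a.startswith("chunk-")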
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
"""限制异步函数的最大并发调用次数的装饰器
Args:
max_size: 最大并发数
waitting_time: 等待时间间隔
Returns:
装饰器函数
"""
def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
"""内部装饰器不使用asyncio.Semaphore以避免使用nest-asyncio"""
__current_size = 0
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
# Wait while the number of in-flight calls is at the maximum
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
@@ -96,8 +160,14 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
"""使用属性包装嵌入函数的装饰器
Args:
**kwargs: 传递给EmbeddingFunc的关键字参数
Returns:
返回一个EmbeddingFunc实例的装饰器
"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)
return new_func
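# Usage sketch (illustrative; the dimensions are assumptions, not values from
# this repository): the decorator turns a bare async function into an
# EmbeddingFunc carrying its embedding_dim and max_token_size attributes.
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
async def _example_embedding(texts: list[str]) -> np.ndarray:
    return np.zeros((len(texts), 1536))  # stand-in for a real model call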
@@ -106,6 +176,14 @@ def wrap_embedding_func_with_attrs(**kwargs):
def load_json(file_name):
"""从文件加载JSON数据
Args:
file_name: JSON文件路径
Returns:
加载的JSON对象如果文件不存在则返回None
"""
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
@@ -113,11 +191,29 @@ def load_json(file_name):
def write_json(json_obj, file_name):
"""将JSON对象写入文件
Args:
json_obj: 要写入的JSON对象
file_name: 目标文件路径
Note:
使用indent=2进行格式化ensure_ascii=False支持中文
"""
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
"""使用tiktoken将字符串编码为tokens
Args:
content: 要编码的字符串
model_name: 使用的模型名称
Returns:
编码后的token列表
"""
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
@@ -126,6 +222,15 @@ def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
"""使用tiktoken将tokens解码为字符串
Args:
tokens: token列表
model_name: 使用的模型名称
Returns:
解码后的字符串
"""
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
@@ -134,6 +239,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
def pack_user_ass_to_openai_messages(*args: str):
"""将用户和助手的对话打包成OpenAI消息格式
Args:
*args: 交替的用户和助手消息
Returns:
OpenAI格式的消息列表奇数位为用户消息偶数位为助手消息
"""
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
@@ -141,7 +254,15 @@ def pack_user_ass_to_openai_messages(*args: str):
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
"""使用多个标记分割字符串
Args:
content: 要分割的字符串
markers: 分割标记列表
Returns:
分割后的字符串列表去除空白
"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
@@ -151,22 +272,44 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
# If we get non-string input, just give it back
"""清理字符串中的HTML转义字符和控制字符
Args:
input: 输入字符串或其他类型
Returns:
清理后的字符串如果输入不是字符串则原样返回
"""
if not isinstance(input, str):
return input
result = html.unescape(input.strip())
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value):
"""检查字符串是否为浮点数格式
Args:
value: 要检查的值
Returns:
是否为浮点数格式的布尔值
"""
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
"""根据token大小截断列表
Args:
list_data: 要截断的列表
key: 从列表项中提取文本的函数
max_token_size: 最大token数量
Returns:
截断后的列表
"""
if max_token_size <= 0:
return []
tokens = 0
@@ -178,6 +321,14 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
def list_of_list_to_csv(data: List[List[str]]) -> str:
"""将二维列表转换为CSV字符串
Args:
data: 二维字符串列表
Returns:
CSV格式的字符串
"""
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(data)
@@ -185,65 +336,123 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
def csv_string_to_list(csv_string: str) -> List[List[str]]:
"""将CSV字符串转换为二维列表
Args:
csv_string: CSV格式的字符串
Returns:
二维字符串列表
"""
output = io.StringIO(csv_string)
reader = csv.reader(output)
return [row for row in reader]
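# Illustrative round trip (not part of this commit): the two CSV helpers are
# inverses for well-formed rows, including fields with embedded commas.
def _example_csv_round_trip():
    rows = [["id", "entity", "description"], ["1", "Node A", "has, commas"]]
    assert csv_string_to_list(list_of_list_to_csv(rows)) == rows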
def save_data_to_file(data, file_name):
"""将数据保存为JSON文件
Args:
data: 要保存的数据
file_name: 目标文件路径
"""
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
"""将GraphML格式的XML文件转换为JSON格式
Args:
xml_file: GraphML文件路径
Returns:
包含节点和边信息的字典解析失败时返回None
Note:
转换后的数据结构为:
{
"nodes": [
{
"id": str,
"entity_type": str,
"description": str,
"source_id": str
}
],
"edges": [
{
"source": str,
"target": str,
"weight": float,
"description": str,
"keywords": str,
"source_id": str
}
]
}
"""
try:
# Parse the XML file
tree = ET.parse(xml_file)
root = tree.getroot()
# Print the root element's tag and attributes to confirm the file loaded correctly
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
# Initialize the output structure
data = {"nodes": [], "edges": []}
# Use the GraphML namespace
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
# Process all nodes
for node in root.findall(".//node", namespace):
node_data = {
# Node ID, with surrounding quotes stripped
"id": node.get("id").strip('"'),
# Entity type; empty string if absent
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
if node.find("./data[@key='d0']", namespace) is not None
else "",
# Description
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
# Source ID
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
# Process all edges
for edge in root.findall(".//edge", namespace):
edge_data = {
# Source and target nodes of the edge
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
# Weight; defaults to 0.0
"weight": float(edge.find("./data[@key='d3']", namespace).text)
if edge.find("./data[@key='d3']", namespace) is not None
else 0.0,
# Description
"description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
# Keywords
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
# Source ID
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
}
data["edges"].append(edge_data)
# Print the number of nodes and edges found
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
return data
@@ -256,31 +465,55 @@ def xml_to_json(xml_file):
def process_combine_contexts(hl, ll):
"""合并高层和低层上下文信息
Args:
hl: 高层上下文的CSV字符串
ll: 低层上下文的CSV字符串
Returns:
合并后的CSV格式字符串
Note:
处理步骤
1. 解析输入的CSV字符串
2. 提取并保留表头
3. 合并数据行并去重
4. 重新格式化为CSV字符串
"""
# Header placeholder
header = None
# Parse the CSV strings
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
# Extract the header
if list_hl:
header = list_hl[0]
list_hl = list_hl[1:]  # drop the header row
if list_ll:
header = list_ll[0]
list_ll = list_ll[1:]  # drop the header row
if header is None:
return ""
# For each data row, keep only the columns after the first
if list_hl:
list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll:
list_ll = [",".join(item[1:]) for item in list_ll if item]
# Merge the data and deduplicate
combined_sources_set = set(filter(None, list_hl + list_ll))
combined_sources = [",\t".join(header)]
# 重新构建CSV字符串
combined_sources = [",\t".join(header)] # 添加表头
# Append the data rows with fresh sequence numbers
for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}")
# Join all rows with newlines
combined_sources = "\n".join(combined_sources)
return combined_sources
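# Illustrative example (not part of this commit): rows that coincide once their
# first column is dropped are deduplicated, then renumbered from 1. Set ordering
# makes the final row order arbitrary.
def _example_combine():
    hl = "id,entity\n1,A\n2,B"
    ll = "id,entity\n1,B\n2,C"
    print(process_combine_contexts(hl, ll))
    # header row plus three data rows: A, B, C with fresh sequence numbers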