import asyncio
import json
import re
from typing import Union
from collections import Counter, defaultdict
import warnings
from .utils import (
logger,
clean_str,
compute_mdhash_id,
decode_tokens_by_tiktoken,
encode_string_by_tiktoken,
is_float_regex,
list_of_list_to_csv,
pack_user_ass_to_openai_messages,
split_string_by_multi_markers,
truncate_list_by_token_size,
process_combine_contexts,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
TextChunkSchema,
QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
def chunking_by_token_size(
    content: str,
    overlap_token_size=128,  # number of overlapping tokens between adjacent chunks
    max_token_size=1024,  # maximum number of tokens per chunk
    tiktoken_model="gpt-4o",  # name of the tokenizer model to use
):
    """Split a long text into overlapping chunks by token count.

    Args:
        content (str): The raw text to split.
        overlap_token_size (int, optional): Number of tokens shared by adjacent chunks. Defaults to 128.
        max_token_size (int, optional): Maximum number of tokens per chunk. Defaults to 1024.
        tiktoken_model (str, optional): Tokenizer model to use. Defaults to "gpt-4o".

    Returns:
        list[dict]: The resulting chunks, each containing:
            - tokens: actual number of tokens in the chunk
            - content: the chunk's text content
            - chunk_order_index: the chunk's sequence number
    """
    # Encode the text into tokens with the specified tokenizer model
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
    # Walk over the tokens with a stride of (max_token_size - overlap_token_size)
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        # Decode the tokens in the current window back into text
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size], model_name=tiktoken_model
        )
        # Append the chunk to the result list
        results.append(
            {
                # Actual token count (the last chunk may be shorter than max_token_size)
                "tokens": min(max_token_size, len(tokens) - start),
                # Strip leading/trailing whitespace from the chunk text
                "content": chunk_content.strip(),
                # Record the chunk's sequence number
                "chunk_order_index": index,
            }
        )
    return results
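
# Illustrative usage (kept as a comment so nothing runs on import; assumes the
# "gpt-4o" encoding is available to tiktoken). For a 2,000-token document and the
# defaults above, the window step is 1024 - 128 = 896 tokens, so three chunks result:
#
#   chunks = chunking_by_token_size(long_text)
#   [c["tokens"] for c in chunks]             # -> [1024, 1024, 208]
#   [c["chunk_order_index"] for c in chunks]  # -> [0, 1, 2]
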
async def _handle_entity_relation_summary(
    entity_or_relation_name: str,  # entity name or relation name
    description: str,  # description text to summarize
    global_config: dict,  # global configuration parameters
) -> str:
    """Summarize the description of an entity or relation.

    When the description exceeds the configured token threshold, an LLM is used to
    condense it so that fewer tokens are needed downstream.

    Args:
        entity_or_relation_name (str): Entity or relation name (used to prompt the LLM).
        description (str): Description text to summarize.
        global_config (dict): Global configuration containing the LLM settings; must include:
            - llm_model_func: the LLM call function
            - llm_model_max_token_size: maximum number of input tokens for the LLM
            - tiktoken_model_name: name of the tokenizer model to use
            - entity_summary_to_max_tokens: maximum number of tokens in the summary

    Returns:
        str: The original text if it is short enough, otherwise the LLM-generated summary.
    """
    # Pull the required parameters out of the global configuration
    use_llm_func: callable = global_config["llm_model_func"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]
    # Count the tokens in the description
    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    # If the token count is below the threshold, return the description unchanged
    if len(tokens) < summary_max_tokens:
        return description
    # Fetch the summarization prompt template
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    # Truncate the description to the maximum number of tokens the LLM can handle
    use_description = decode_tokens_by_tiktoken(
        tokens[:llm_max_tokens], model_name=tiktoken_model_name
    )
    # Build the prompt context
    context_base = dict(
        entity_name=entity_or_relation_name,
        # Split the description on the field separator
        description_list=use_description.split(GRAPH_FIELD_SEP),
    )
    # Format the prompt template
    use_prompt = prompt_template.format(**context_base)
    # Log for debugging
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    # Call the LLM to generate the summary
    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
    return summary
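
# Minimal sketch of when summarization kicks in (hypothetical numbers): with
# entity_summary_to_max_tokens=500, a 120-token description is returned verbatim,
# while a 2,000-token description is truncated to llm_model_max_token_size and sent
# to the LLM with PROMPTS["summarize_entity_descriptions"]:
#
#   short = await _handle_entity_relation_summary("TESLA", short_desc, global_config)
#   long = await _handle_entity_relation_summary("TESLA", long_desc, global_config)
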
async def _handle_single_entity_extraction(
    record_attributes: list[str],  # attribute list for one record
    chunk_key: str,  # unique identifier of the current chunk
):
    """Extract a single entity from a record's attribute list.

    Args:
        record_attributes (list[str]): Attribute list describing one entity.
        chunk_key (str): Unique identifier of the current chunk, used to track provenance.

    Returns:
        dict: The entity information, or None if the record is not a valid entity record.
    """
    # Check that this is a valid entity record
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # Extract and clean the entity name
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():  # return None if the entity name is empty
        return None
    # Extract and clean the entity type and description
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key  # record where the data came from
    # Return the entity information
    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=entity_source_id,
    )
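
# Illustrative input (hypothetical values): the extraction prompt asks the LLM to emit
# tuple-delimited records, so a valid entity record looks roughly like
#
#   record_attributes = ['"entity"', '"TESLA"', '"organization"', '"Electric-vehicle maker."']
#   await _handle_single_entity_extraction(record_attributes, chunk_key="chunk-0001")
#   # -> {"entity_name": ..., "entity_type": ..., "description": ..., "source_id": "chunk-0001"}
#
# The exact quoting of the output depends on clean_str, so the values are left abstract.
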
async def _handle_single_relationship_extraction(
    record_attributes: list[str],  # attribute list for one record
    chunk_key: str,  # unique identifier of the current chunk
):
    """Extract a single relationship from a record's attribute list.

    Args:
        record_attributes (list[str]): Attribute list describing one relationship.
        chunk_key (str): Unique identifier of the current chunk, used to track provenance.

    Returns:
        dict: The relationship information, or None if the record is not a valid relationship record.
    """
    # Check that this is a valid relationship record
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # Extract and clean the source and target nodes of the relationship
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    # Extract and clean the relationship description and keywords
    edge_description = clean_str(record_attributes[3])
    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key  # record where the data came from
    # Parse the relationship weight from the last attribute, defaulting to 1.0
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    # Return the relationship information
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
    )
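
# Illustrative input (hypothetical values): a relationship record carries at least the
# marker, two endpoints, a description, and keywords, usually followed by a numeric
# strength that becomes the weight (falling back to 1.0 when the last field is not a
# plain number):
#
#   record_attributes = ['"relationship"', '"TESLA"', '"ELON MUSK"',
#                        '"Founded and led by."', '"founder, CEO"', '9']
#   # -> weight == 9.0, src_id/tgt_id taken from the two endpoint fields
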
async def _merge_nodes_then_upsert(
    entity_name: str,  # entity name
    nodes_data: list[dict],  # node records to merge
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    global_config: dict,  # global configuration parameters
):
    """Merge node records and upsert the result into the knowledge graph.

    Args:
        entity_name (str): Entity name.
        nodes_data (list[dict]): Node records to merge.
        knowledge_graph_inst (BaseGraphStorage): Knowledge graph instance.
        global_config (dict): Global configuration parameters.

    Returns:
        dict: The merged node data.
    """
    already_entity_types = []
    already_source_ids = []
    already_description = []
    # Check whether the node already exists in the knowledge graph
    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node is not None:
        # If it exists, collect its existing entity type, source IDs, and description
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_node["description"])
    # Pick the merged entity type (the most frequent type wins)
    entity_type = sorted(
        Counter(
            [dp["entity_type"] for dp in nodes_data] + already_entity_types
        ).items(),
        key=lambda x: x[1],
        reverse=True,
    )[0][0]
    # Merge descriptions and source IDs
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in nodes_data] + already_source_ids)
    )
    # Summarize the merged description with the LLM if it is too long
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config
    )
    # Assemble the node data
    node_data = dict(
        entity_type=entity_type,
        description=description,
        source_id=source_id,
    )
    # Upsert the node into the knowledge graph
    await knowledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data
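
# Illustrative merge (hypothetical data): if two freshly extracted mentions are typed
# "PERSON" and the node already stored in the graph says "ORGANIZATION", the majority
# vote above keeps "PERSON"; descriptions are deduplicated and joined with
# GRAPH_FIELD_SEP before the optional LLM summarization step. For the type vote alone:
#
#   sorted(Counter(["PERSON", "PERSON", "ORGANIZATION"]).items(),
#          key=lambda x: x[1], reverse=True)[0][0]   # -> 'PERSON'
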
async def _merge_edges_then_upsert(
    src_id: str,  # source node ID
    tgt_id: str,  # target node ID
    edges_data: list[dict],  # edge records to merge
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    global_config: dict,  # global configuration parameters
):
    """Merge edge records and upsert the result into the knowledge graph.

    Args:
        src_id (str): Source node ID.
        tgt_id (str): Target node ID.
        edges_data (list[dict]): Edge records to merge.
        knowledge_graph_inst (BaseGraphStorage): Knowledge graph instance.
        global_config (dict): Global configuration parameters.

    Returns:
        dict: The merged edge data.
    """
    already_weights = []
    already_source_ids = []
    already_description = []
    already_keywords = []
    # Check whether the edge already exists in the knowledge graph
    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        # If it exists, collect its existing weight, source IDs, description, and keywords
        already_weights.append(already_edge["weight"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_edge["description"])
        already_keywords.extend(
            split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
        )
    # Compute the merged weight (weights accumulate)
    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    # Merge descriptions, keywords, and source IDs
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in edges_data] + already_description))
    )
    keywords = GRAPH_FIELD_SEP.join(
        sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in edges_data] + already_source_ids)
    )
    # Make sure both endpoints exist in the knowledge graph
    for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            await knowledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "source_id": source_id,
                    "description": description,
                    "entity_type": '"UNKNOWN"',
                },
            )
    # Summarize the merged description with the LLM if it is too long
    description = await _handle_entity_relation_summary(
        (src_id, tgt_id), description, global_config
    )
    # Upsert the edge into the knowledge graph
    await knowledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight,
            description=description,
            keywords=keywords,
            source_id=source_id,
        ),
    )
    edge_data = dict(
        src_id=src_id,
        tgt_id=tgt_id,
        description=description,
        keywords=keywords,
    )
    return edge_data
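
# Illustrative weight merge (hypothetical numbers): unlike entity types, edge weights
# are summed, so repeated evidence for the same relation strengthens it. Two new
# mentions with weights 2.0 and 3.0 merged onto a stored edge of weight 1.0 give 6.0:
#
#   sum([2.0, 3.0] + [1.0])  # -> 6.0
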
async def extract_entities(
    chunks: dict[str, TextChunkSchema],  # text chunks keyed by chunk ID
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    entity_vdb: BaseVectorStorage,  # entity vector database
    relationships_vdb: BaseVectorStorage,  # relationship vector database
    global_config: dict,  # global configuration parameters
) -> Union[BaseGraphStorage, None]:
    """Extract entities and relationships from text chunks and update the knowledge graph.

    Args:
        chunks: Text chunks to process.
        knowledge_graph_inst: Knowledge graph instance used to store entities and relationships.
        entity_vdb: Vector database for entity embeddings.
        relationships_vdb: Vector database for relationship embeddings.
        global_config: Global parameters, including the LLM configuration.

    Returns:
        The updated knowledge graph instance, or None if nothing could be extracted.
    """
    # Get the LLM function and the maximum number of gleaning rounds from the global config
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    # Turn the chunks dict into an ordered list
    ordered_chunks = list(chunks.items())
    # Fetch the prompt template and delimiter configuration
    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    )
    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
    # Initialize progress counters
    already_processed = 0
    already_entities = 0
    already_relations = 0
    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """Process a single text chunk, extracting entities and relationships.

        Args:
            chunk_key_dp: A (chunk_key, chunk_data) tuple.

        Returns:
            (nodes_dict, edges_dict): The extracted entities and relationships.
        """
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        # Build the initial prompt and get the LLM response
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        final_result = await use_llm_func(hint_prompt)
        # Keep the conversation history
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
        # Gleaning loop: ask the LLM for anything it missed
        for now_glean_index in range(entity_extract_max_gleaning):
            # Extract additional entities and relationships
            glean_result = await use_llm_func(continue_prompt, history_messages=history)
            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break
            # Ask whether another gleaning round is needed
            if_loop_result: str = await use_llm_func(
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break
        # Split the raw output into individual records
        records = split_string_by_multi_markers(
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )
        # Collect candidate entities and relationships
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        # Process each record
        for record in records:
            # Extract the attribute string inside the parentheses
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            # Split the attribute string into a list
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )
            # Try to interpret the record as an entity
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue
            # Otherwise try to interpret it as a relationship
            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        # Update the progress counters
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        # Show progress
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)
    # Process all text chunks concurrently
    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()  # clear the progress line
    # Merge the per-chunk extraction results
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[tuple(sorted(k))].extend(v)
    # Merge entities and upsert them into the knowledge graph
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )
    # Merge relationships and upsert them into the knowledge graph
    all_relationships_data = await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )
    # Sanity-check the extraction results
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if not len(all_relationships_data):
        logger.warning(
            "Didn't extract any relationships, maybe your LLM is not working"
        )
        return None
    # Update the entity vector database
    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)
    # Update the relationship vector database
    if relationships_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "content": dp["keywords"]
                + dp["src_id"]
                + dp["tgt_id"]
                + dp["description"],
            }
            for dp in all_relationships_data
        }
        await relationships_vdb.upsert(data_for_vdb)
    return knowledge_graph_inst
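
# Sketch of a call site (hypothetical wiring; the storage objects and LLM wrapper are
# assumptions, not defined in this module). The chunks dict is typically keyed by
# compute_mdhash_id(content, prefix="chunk-"):
#
#   graph = await extract_entities(
#       chunks,
#       knowledge_graph_inst=graph_storage,
#       entity_vdb=entity_store,
#       relationships_vdb=relation_store,
#       global_config={
#           "llm_model_func": my_async_llm,
#           "entity_extract_max_gleaning": 1,
#           "llm_model_max_token_size": 32768,
#           "tiktoken_model_name": "gpt-4o",
#           "entity_summary_to_max_tokens": 500,
#       },
#   )
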
async def local_query(
    query,  # the user's query text
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph storage instance
    entities_vdb: BaseVectorStorage,  # entity vector database
    relationships_vdb: BaseVectorStorage,  # relationship vector database
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # storage for the original text chunks
    query_param: QueryParam,  # query parameter configuration
    global_config: dict,  # global configuration dict
) -> str:
    """Local query: retrieve related information from the knowledge graph and generate an answer.

    Args:
        query: The user's query text.
        knowledge_graph_inst: Instance used to store and query the knowledge graph.
        entities_vdb: Vector database for entity retrieval.
        relationships_vdb: Vector database for relationship retrieval.
        text_chunks_db: Key-value store holding the original text chunks.
        query_param: Configuration object with query-related parameters.
        global_config: Global configuration dict.

    Returns:
        str: The response text for the query.
    """
    context = None
    # Get the language model function from the global config
    use_model_func = global_config["llm_model_func"]
    # Build the keyword-extraction prompt
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    # Try to parse the keyword JSON returned by the language model
    try:
        keywords_data = json.loads(result)
        keywords = keywords_data.get("low_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        # First parse failed; clean up the response text and retry
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            keywords = keywords_data.get("low_level_keywords", [])
            keywords = ", ".join(keywords)
        except json.JSONDecodeError as e:
            # Parsing failed for good; return the failure response
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    # If keywords were extracted, build the query context
    if keywords:
        context = await _build_local_query_context(
            keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    # If only the context is needed, return it directly
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    # Build the RAG system prompt and get the language model response
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    # Clean up the language model's response text
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
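
# Sketch of a call (hypothetical wiring; QueryParam fields follow base.py). Passing
# only_need_context=True returns the assembled CSV context instead of an LLM answer:
#
#   answer = await local_query(
#       "Who founded the company?",
#       graph_storage, entity_store, relation_store, chunk_store,
#       QueryParam(mode="local"), global_config,
#   )
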
async def _build_local_query_context(
    query,  # query keywords
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    entities_vdb: BaseVectorStorage,  # entity vector database
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    query_param: QueryParam,  # query parameters
):
    """Build the context for a local query.

    Starting from the query keywords, retrieve the related entities, relationships, and
    source text from the knowledge graph and vector database, and organize them into a
    structured context.
    """
    # Retrieve the most relevant entities from the entity vector database
    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return None
    # Fetch the retrieved entity nodes concurrently
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    # Get the degree (number of connections) of each entity node
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    # Merge node information, attaching the entity name and rank
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    # Find the text units and relationships most related to these entities
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )
    # Log how much data is used
    logger.info(
        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )
    # Build the entity section as CSV data
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)
    # Build the relationship section as CSV data
    relations_section_list = [
        ["id", "source", "target", "description", "keywords", "weight", "rank"]
    ]
    for i, e in enumerate(use_relations):
        relations_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["keywords"],
                e["weight"],
                e["rank"],
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)
    # Build the text unit section as CSV data
    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    # Return the formatted context
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],  # entity node data list
    query_param: QueryParam,  # query parameter configuration
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
):
    """Find the text units most related to the given entities.

    Analyzes each entity's one-hop neighbors and edges to pick the most relevant
    source text units.
    """
    # Get the text unit IDs associated with each entity node
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    # Get each entity's edges (relationships)
    edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    # Collect all one-hop neighbor nodes
    all_one_hop_nodes = set()
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])
    # Fetch the one-hop neighbor node data
    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await asyncio.gather(
        *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
    )
    # Build a lookup of the text units referenced by each one-hop neighbor
    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None and "source_id" in v  # validate the node data
    }
    # Build a lookup of all text units, recording their order and relation counts
    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
            if c_id in all_text_units_lookup:
                continue
            # Count how many one-hop neighbors also reference this text unit
            relation_counts = 0
            if this_edges:  # make sure the edges exist
                for e in this_edges:
                    if (
                        e[1] in all_one_hop_text_units_lookup
                        and c_id in all_one_hop_text_units_lookup[e[1]]
                    ):
                        relation_counts += 1
            # Fetch the text unit data
            chunk_data = await text_chunks_db.get_by_id(c_id)
            if chunk_data is not None and "content" in chunk_data:  # validate the content
                all_text_units_lookup[c_id] = {
                    "data": chunk_data,
                    "order": index,
                    "relation_counts": relation_counts,
                }
    # Filter out invalid entries and make sure content is present
    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v is not None and v.get("data") is not None and "content" in v["data"]
    ]
    if not all_text_units:
        logger.warning("No valid text units found")
        return []
    # Sort by entity order, then by descending relation count
    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )
    # Truncate the list to the token budget
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )
    # Return only the text unit payloads
    all_text_units = [t["data"] for t in all_text_units]
    return all_text_units
async def _find_most_related_edges_from_entities(
    node_datas: list[dict],  # entity node data list
    query_param: QueryParam,  # query parameter configuration
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
):
    """Find the edges (relationships) most related to the given entities.

    Collects all of the entities' edges and ranks them by degree and weight.
    """
    # Fetch all related edges
    all_related_edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    # Collect the unique edges (normalizing the edge direction)
    all_edges = set()
    for this_edges in all_related_edges:
        all_edges.update([tuple(sorted(e)) for e in this_edges])
    all_edges = list(all_edges)
    # Fetch edge details and degrees
    all_edges_pack = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
    )
    all_edges_degree = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
    )
    # Merge the edge information
    all_edges_data = [
        {"src_tgt": k, "rank": d, **v}
        for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
        if v is not None
    ]
    # Sort by degree, then by weight
    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    # Truncate the list to the token budget
    all_edges_data = truncate_list_by_token_size(
        all_edges_data,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )
    return all_edges_data
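
# Note on the direction normalization above (tiny illustrative check): because the
# graph is treated as undirected here, (A, B) and (B, A) collapse to the same key:
#
#   tuple(sorted(("B", "A")))  # -> ('A', 'B')
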
async def global_query(
    query,  # the user's query text
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    entities_vdb: BaseVectorStorage,  # entity vector database
    relationships_vdb: BaseVectorStorage,  # relationship vector database
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    query_param: QueryParam,  # query parameter configuration
    global_config: dict,  # global configuration
) -> str:
    """Global query driven by relationship-level vector retrieval.

    Extracts high-level concept keywords, retrieves the related relationships and
    entities, and builds the query context from them.
    """
    context = None
    # Get the language model function
    use_model_func = global_config["llm_model_func"]
    # Extract the query keywords
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    # Parse the keyword JSON result
    try:
        keywords_data = json.loads(result)
        keywords = keywords_data.get("high_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        # Clean up the response text and retry parsing
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            keywords = keywords_data.get("high_level_keywords", [])
            keywords = ", ".join(keywords)
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    # Build the query context from the keywords
    if keywords:
        context = await _build_global_query_context(
            keywords,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
    # If only the context is needed, return it directly
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    # Build the system prompt and get the language model response
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    # Clean up the response text
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
async def _build_global_query_context(
    keywords,  # query keywords
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    entities_vdb: BaseVectorStorage,  # entity vector database
    relationships_vdb: BaseVectorStorage,  # relationship vector database
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    query_param: QueryParam,  # query parameters
):
    """Build the context for a global query.

    Uses relationship-level vector retrieval to find the most relevant relationships,
    entities, and text units.
    """
    # Retrieve related relationships from the relationship vector database
    results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
    if not len(results):
        return None
    # Fetch the detailed edge data
    edge_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
    )
    # Check data integrity
    if not all([n is not None for n in edge_datas]):
        logger.warning("Some edges are missing, maybe the storage is damaged")
    # Get the edge degrees
    edge_degree = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
    )
    # Merge the edge information
    edge_datas = [
        {"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v}
        for k, v, d in zip(results, edge_datas, edge_degree)
        if v is not None
    ]
    # Sort by degree, then by weight
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    # Truncate the relationship list to the token budget
    edge_datas = truncate_list_by_token_size(
        edge_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )
    # Fetch the related entities and text units
    use_entities = await _find_most_related_entities_from_relationships(
        edge_datas, query_param, knowledge_graph_inst
    )
    use_text_units = await _find_related_text_unit_from_relationships(
        edge_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    # Log how much data is used
    logger.info(
        f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
    )
    # Build the relationship section as CSV data
    relations_section_list = [
        ["id", "source", "target", "description", "keywords", "weight", "rank"]
    ]
    for i, e in enumerate(edge_datas):
        relations_section_list.append(
            [
                i,
                e["src_id"],
                e["tgt_id"],
                e["description"],
                e["keywords"],
                e["weight"],
                e["rank"],
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)
    # Build the entity section as CSV data
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(use_entities):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)
    # Build the text unit section as CSV data
    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    # Return the formatted context
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
async def _find_most_related_entities_from_relationships(
    edge_datas: list[dict],  # relationship data list
    query_param: QueryParam,  # query parameter configuration
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
):
    """Find the entity nodes most related to the given relationships.

    Looks at the source and target of each relationship and fetches the details of
    every entity involved.
    """
    # Collect all entity names referenced by the relationships
    entity_names = set()
    for e in edge_datas:
        entity_names.add(e["src_id"])
        entity_names.add(e["tgt_id"])
    # Fetch the detailed data of all entity nodes concurrently
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
    )
    # Fetch the degree (number of connections) of each entity node concurrently
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
    )
    # Merge the entity information, attaching the entity name and rank
    node_datas = [
        {**n, "entity_name": k, "rank": d}
        for k, n, d in zip(entity_names, node_datas, node_degrees)
    ]
    # Truncate the entity list to the token budget
    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_local_context,
    )
    return node_datas
async def _find_related_text_unit_from_relationships(
    edge_datas: list[dict],  # relationship data list
    query_param: QueryParam,  # query parameter configuration
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
):
    """Find the text units related to the given relationships.

    Follows each relationship's source IDs back to the original text units.
    """
    # Extract the text unit IDs from each relationship
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in edge_datas
    ]
    # Build a lookup of text units, recording the order in which they appear
    all_text_units_lookup = {}
    # Fetch the data for each text unit
    for index, unit_list in enumerate(text_units):
        for c_id in unit_list:
            if c_id not in all_text_units_lookup:
                all_text_units_lookup[c_id] = {
                    "data": await text_chunks_db.get_by_id(c_id),
                    "order": index,
                }
    # Check data integrity (the chunk payload may be missing from the KV store)
    if any([v["data"] is None for v in all_text_units_lookup.values()]):
        logger.warning("Text chunks are missing, maybe the storage is damaged")
    # Drop entries without a payload and attach the chunk ID
    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v["data"] is not None
    ]
    # Sort by order of appearance
    all_text_units = sorted(all_text_units, key=lambda x: x["order"])
    # Truncate the list to the token budget
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )
    # Return only the text unit payloads
    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
    return all_text_units
async def hybrid_query(
    query,  # the user's query text
    knowledge_graph_inst: BaseGraphStorage,  # knowledge graph instance
    entities_vdb: BaseVectorStorage,  # entity vector database
    relationships_vdb: BaseVectorStorage,  # relationship vector database
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    query_param: QueryParam,  # query parameter configuration
    global_config: dict,  # global configuration
) -> str:
    """Hybrid query that combines the strengths of local and global queries.

    Extracts both high-level and low-level keywords, builds a context for each, and
    merges the two contexts.
    """
    # Initialize the contexts and the model function
    low_level_context = None
    high_level_context = None
    use_model_func = global_config["llm_model_func"]
    # Build the keyword-extraction prompt
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)
    # Get the language model response
    result = await use_model_func(kw_prompt)
    try:
        # Try to parse the keyword JSON
        keywords_data = json.loads(result)
        hl_keywords = keywords_data.get("high_level_keywords", [])
        ll_keywords = keywords_data.get("low_level_keywords", [])
        hl_keywords = ", ".join(hl_keywords)
        ll_keywords = ", ".join(ll_keywords)
    except json.JSONDecodeError:
        # Clean up the response text and retry parsing
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
            hl_keywords = ", ".join(hl_keywords)
            ll_keywords = ", ".join(ll_keywords)
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    # If there are low-level keywords, build the local query context
    if ll_keywords:
        low_level_context = await _build_local_query_context(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    # If there are high-level keywords, build the global query context
    if hl_keywords:
        high_level_context = await _build_global_query_context(
            hl_keywords,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
    # Merge the two contexts
    context = combine_contexts(high_level_context, low_level_context)
    # If only the context is needed, return it directly
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    # Build the system prompt and get the language model response
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    # Clean up the response text
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
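
# Illustrative keyword split (hypothetical LLM output): for the question
# "How does Tesla's battery supply chain affect its margins?", the keyword-extraction
# prompt might yield
#
#   {"high_level_keywords": ["supply chain", "profit margins"],
#    "low_level_keywords": ["Tesla", "battery"]}
#
# so the low-level terms drive entity retrieval (local context) and the high-level
# terms drive relationship retrieval (global context) before the two are merged by
# combine_contexts below.
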
def combine_contexts(high_level_context, low_level_context):
    """Combine the high-level and low-level contexts.

    Extracts the entity, relationship, and source sections from both contexts and
    merges them.
    """
    # Inner helper that pulls the individual sections out of a context string
    def extract_sections(context):
        # Match the entities section
        entities_match = re.search(
            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        # Match the relationships section
        relationships_match = re.search(
            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        # Match the sources section
        sources_match = re.search(
            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        # Return the matched content, or empty strings where a section is absent
        entities = entities_match.group(1) if entities_match else ""
        relationships = relationships_match.group(1) if relationships_match else ""
        sources = sources_match.group(1) if sources_match else ""
        return entities, relationships, sources

    # Extract the sections from the high-level context
    if high_level_context is None:
        warnings.warn(
            "High Level context is None. Return empty High entity/relationship/source"
        )
        hl_entities, hl_relationships, hl_sources = "", "", ""
    else:
        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
    # Extract the sections from the low-level context
    if low_level_context is None:
        warnings.warn(
            "Low Level context is None. Return empty Low entity/relationship/source"
        )
        ll_entities, ll_relationships, ll_sources = "", "", ""
    else:
        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
    # Merge and deduplicate the entity sections
    combined_entities = process_combine_contexts(hl_entities, ll_entities)
    # Merge and deduplicate the relationship sections
    combined_relationships = process_combine_contexts(
        hl_relationships, ll_relationships
    )
    # Merge and deduplicate the source sections
    combined_sources = process_combine_contexts(hl_sources, ll_sources)
    # Return the formatted, combined context
return f"""
-----Entities-----
```csv
{combined_entities}
```
-----Relationships-----
```csv
{combined_relationships}
```
-----Sources-----
```csv
{combined_sources}
```
"""
async def naive_query(
    query,  # the user's query text
    chunks_vdb: BaseVectorStorage,  # text chunk vector database
    text_chunks_db: BaseKVStorage[TextChunkSchema],  # text chunk storage
    query_param: QueryParam,  # query parameter configuration
    global_config: dict,  # global configuration
):
    """Naive query based directly on text-chunk vector retrieval.

    Does not use the knowledge graph; related chunks are found purely by vector
    similarity.
    """
    # Get the language model function
    use_model_func = global_config["llm_model_func"]
    # Retrieve related text chunks from the vector database
    results = await chunks_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return PROMPTS["fail_response"]
    # Collect the IDs of the retrieved chunks
    chunks_ids = [r["id"] for r in results]
    # Fetch the full chunk contents from the text chunk storage
    chunks = await text_chunks_db.get_by_ids(chunks_ids)
    # Truncate the chunk list to the token budget
    maybe_trun_chunks = truncate_list_by_token_size(
        chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )
    # Log the truncation
    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    # Join the chunk contents into a single string
    section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
    # If only the context is needed, return the joined text directly
    if query_param.only_need_context:
        return section
    # Build the system prompt and get the language model response
    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    # Clean up the response text
    if len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt) :]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
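
# Sketch of a call (hypothetical wiring; chunks_vdb is the vector store built over raw
# chunks, text_chunks_db the KV store that holds their content):
#
#   answer = await naive_query(
#       "Summarize the report.",
#       chunks_vdb=chunk_vectors, text_chunks_db=chunk_store,
#       query_param=QueryParam(mode="naive"), global_config=global_config,
#   )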