# 标准库导入 import asyncio # 异步IO支持 import json # JSON数据处理 import re # 正则表达式支持 from typing import Union # 类型提示:联合类型 from collections import ( Counter, # 计数器集合类 defaultdict # 带默认值的字典 ) import warnings # 警告控制 # 从本地utils模块导入工具函数 from .utils import ( logger, # 日志记录器 clean_str, # 字符串清理函数 compute_mdhash_id, # 计算MD5哈希ID decode_tokens_by_tiktoken, # tiktoken解码函数 encode_string_by_tiktoken, # tiktoken编码函数 is_float_regex, # 浮点数检查函数 list_of_list_to_csv, # 列表转CSV函数 pack_user_ass_to_openai_messages, # OpenAI消息打包函数 split_string_by_multi_markers, # 多标记字符串分割函数 truncate_list_by_token_size, # 基于token大小截断列表 process_combine_contexts, # 上下文合并处理函数 ) # 从本地base模块导入基础类 from .base import ( BaseGraphStorage, # 图存储基类 BaseKVStorage, # 键值存储基类 BaseVectorStorage, # 向量存储基类 TextChunkSchema, # 文本块模式定义 QueryParam, # 查询参数类 ) # 从本地prompt模块导入提示相关常量 from .prompt import ( GRAPH_FIELD_SEP, # 图字段分隔符 PROMPTS # 提示模板集合 ) def chunking_by_token_size( content: str, overlap_token_size=128, # 重叠部分的token数量 max_token_size=1024, # 每个chunk的最大token数量 tiktoken_model="gpt-4o" # 使用的tokenizer模型名称 ): """将长文本按照token数量切分成多个重叠的chunk Args: content (str): 需要切分的原始文本内容 overlap_token_size (int, optional): 相邻chunk之间的重叠token数. Defaults to 128. max_token_size (int, optional): 每个chunk的最大token数. Defaults to 1024. tiktoken_model (str, optional): 使用的tokenizer模型. Defaults to "gpt-4o". Returns: list[dict]: 包含切分后的chunk列表,每个chunk包含以下字段: - tokens: chunk实际包含的token数 - content: chunk的文本内容 - chunk_order_index: chunk的序号 """ # 使用指定的tokenizer模型将文本编码为tokens tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) results = [] # 按照指定的步长(max_token_size - overlap_token_size)遍历tokens # for index, start in enumerate( range(0, len(tokens), max_token_size - overlap_token_size) ): # 解码当前窗口范围内的tokens为文本 chunk_content = decode_tokens_by_tiktoken( tokens[start : start + max_token_size], model_name=tiktoken_model ) # 将chunk添加到结果列表中 results.append( { # 计算实际token数量(最后一个chunk可能不足max_token_size) "tokens": min(max_token_size, len(tokens) - start), # 去除chunk首尾的空白字符 "content": chunk_content.strip(), # 记录chunk的序号 "chunk_order_index": index, } ) return results async def _handle_entity_relation_summary( entity_or_relation_name: str, # 实体名称或关系名称 description: str, # 需要总结的描述文本 global_config: dict, # 全局配置参数 ) -> str: """对实体或关系的描述文本进行总结 当描述文本的token数量超过指定阈值时,使用LLM对其进行总结,以减少token数量 Args: entity_or_relation_name (str): 实体名称或关系名称(用于提示LLM) description (str): 需要总结的描述文本 global_config (dict): 包含LLM配置的全局参数字典,必须包含: - llm_model_func: LLM调用函数 - llm_model_max_token_size: LLM输入的最大token数 - tiktoken_model_name: 使用的tokenizer模型名称 - entity_summary_to_max_tokens: 总结后的最大token数 Returns: str: 如果原文本较短则直接返回,否则返回LLM总结后的文本 """ # 从全局配置中获取必要的参数 use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] # 计算描述文本的token数量 tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) # 如果token数量小于阈值,无需总结直接返回 if len(tokens) < summary_max_tokens: return description # 获取总结用的提示模板 prompt_template = PROMPTS["summarize_entity_descriptions"] # 截取LLM能处理的最大token数量 use_description = decode_tokens_by_tiktoken( tokens[:llm_max_tokens], model_name=tiktoken_model_name ) # 构建提示上下文 context_base = dict( entity_name=entity_or_relation_name, # 按字段分隔符分割描述文本 description_list=use_description.split(GRAPH_FIELD_SEP), ) # 格式化提示模板 use_prompt = prompt_template.format(**context_base) # 记录调试日志 logger.debug(f"Trigger summary: {entity_or_relation_name}") # 调用LLM生成总结 summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens) return summary async def _handle_single_entity_extraction( record_attributes: list[str], # 记录的属性列表 chunk_key: str, # 当前chunk的唯一标识符 ): """从记录属性中提取单个实体的信息 Args: record_attributes (list[str]): 包含实体信息的属性列表 chunk_key (str): 当前chunk的唯一标识符,用于追踪数据来源 Returns: dict: 包含实体信息的字典,如果记录无效则返回None """ # 检查记录是否为有效的实体记录 if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None # 提取并清理实体名称 entity_name = clean_str(record_attributes[1].upper()) if not entity_name.strip(): # 如果实体名称为空,返回None return None # 提取并清理实体类型和描述 entity_type = clean_str(record_attributes[2].upper()) entity_description = clean_str(record_attributes[3]) entity_source_id = chunk_key # 记录数据来源 # 返回包含实体信息的字典 return dict( entity_name=entity_name, entity_type=entity_type, description=entity_description, source_id=entity_source_id, ) async def _handle_single_relationship_extraction( record_attributes: list[str], # 记录的属性列表 chunk_key: str, # 当前chunk的唯一标识符 ): """从记录属性中提取单个关系的信息 Args: record_attributes (list[str]): 包含关系信息的属性列表 chunk_key (str): 当前chunk的唯一标识符,用于追踪数据来源 Returns: dict: 包含关系信息的字典,如果记录无效则返回None """ # 检查记录是否为有效的关系记录 if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': return None # 提取并清理关系的源节点和目标节点 source = clean_str(record_attributes[1].upper()) target = clean_str(record_attributes[2].upper()) # 提取并清理关系的描述和关键词 edge_description = clean_str(record_attributes[3]) edge_keywords = clean_str(record_attributes[4]) edge_source_id = chunk_key # 记录数据来源 # 提取并计算关系的权重,默认为1.0 weight = ( float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0 ) # 返回包含关系信息的字典 return dict( src_id=source, tgt_id=target, weight=weight, description=edge_description, keywords=edge_keywords, source_id=edge_source_id, ) async def _merge_nodes_then_upsert( entity_name: str, # 实体名称 nodes_data: list[dict], # 待合并的节点数据列表 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 global_config: dict, # 全局配置参数 ): """合并节点数据并更新到知识图谱中 Args: entity_name (str): 实体名称 nodes_data (list[dict]): 待合并的节点数据列表 knowledge_graph_inst (BaseGraphStorage): 知识图谱实例 global_config (dict): 全局配置参数 Returns: dict: 更新后的节点数据 """ already_entitiy_types = [] already_source_ids = [] already_description = [] # 检查知识图谱中是否已存在该节点 already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: # 如果存在,提取已有的实体类型、来源ID和描述 already_entitiy_types.append(already_node["entity_type"]) already_source_ids.extend( split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP]) ) already_description.append(already_node["description"]) # 计算合并后的实体类型(选择出现频率最高的类型) entity_type = sorted( Counter( [dp["entity_type"] for dp in nodes_data] + already_entitiy_types ).items(), key=lambda x: x[1], reverse=True, )[0][0] # 合并描述和来源ID description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in nodes_data] + already_description)) ) source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in nodes_data] + already_source_ids) ) # 使用LLM对合并后的描述进行总结 description = await _handle_entity_relation_summary( entity_name, description, global_config ) # 构建节点数据字典 node_data = dict( entity_type=entity_type, description=description, source_id=source_id, ) # 更新或插入节点到知识图谱中 await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) node_data["entity_name"] = entity_name return node_data async def _merge_edges_then_upsert( src_id: str, # 源节点ID tgt_id: str, # 目标节点ID edges_data: list[dict], # 待合并的边数据列表 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 global_config: dict, # 全局配置参数 ): """合并边数据并更新到知识图谱中 Args: src_id (str): 源节点ID tgt_id (str): 目标节点ID edges_data (list[dict]): 待合并的边数据列表 knowledge_graph_inst (BaseGraphStorage): 知识图谱实例 global_config (dict): 全局配置参数 Returns: dict: 更新后的边数据 """ already_weights = [] already_source_ids = [] already_description = [] already_keywords = [] # 检查知识图谱中是否已存在该边 if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) # 如果存在,提取已有的权重、来源ID、描述和关键词 already_weights.append(already_edge["weight"]) already_source_ids.extend( split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) ) already_description.append(already_edge["description"]) already_keywords.extend( split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) ) # 计算合并后的权重(累加) weight = sum([dp["weight"] for dp in edges_data] + already_weights) # 合并描述、关键词和来源ID description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in edges_data] + already_description)) ) keywords = GRAPH_FIELD_SEP.join( sorted(set([dp["keywords"] for dp in edges_data] + already_keywords)) ) source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in edges_data] + already_source_ids) ) # 确保源节点和目标节点存在于知识图谱中 for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ "source_id": source_id, "description": description, "entity_type": '"UNKNOWN"', }, ) # 使用LLM对合并后的描述进行总结 description = await _handle_entity_relation_summary( (src_id, tgt_id), description, global_config ) # 更新或插入边到知识图谱中 await knowledge_graph_inst.upsert_edge( src_id, tgt_id, edge_data=dict( weight=weight, description=description, keywords=keywords, source_id=source_id, ), ) edge_data = dict( src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, ) return edge_data async def extract_entities( chunks: dict[str, TextChunkSchema], # 文本块字典,键为chunk的唯一标识符 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 entity_vdb: BaseVectorStorage, # 实体向量数据库 relationships_vdb: BaseVectorStorage, # 关系向量数据库 global_config: dict, # 全局配置参数 ) -> Union[BaseGraphStorage, None]: """从文本块中提取实体和关系,并更新到知识图谱中 Args: chunks: 待处理的文本块字典 knowledge_graph_inst: 用于存储实体和关系的知识图谱实例 entity_vdb: 用于存储实体向量的数据库 relationships_vdb: 用于存储关系向量的数据库 global_config: 包含LLM配置等全局参数 Returns: 更新后的知识图谱实例,如果提取失败则返回None """ # 从全局配置中获取LLM函数和最大提取次数 use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] # 将chunks字典转换为有序列表 ordered_chunks = list(chunks.items()) # 获取提示模板和分隔符配置 entity_extract_prompt = PROMPTS["entity_extraction"] context_base = dict( tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]), ) continue_prompt = PROMPTS["entiti_continue_extraction"] if_loop_prompt = PROMPTS["entiti_if_loop_extraction"] # 初始化处理计数器 already_processed = 0 already_entities = 0 already_relations = 0 async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): """处理单个文本块,提取实体和关系 Args: chunk_key_dp: (chunk_key, chunk_data)元组 Returns: (nodes_dict, edges_dict): 提取的实体和关系字典 """ nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] # 构建初始提示并获取LLM响应 hint_prompt = entity_extract_prompt.format(**context_base, input_text=content) final_result = await use_llm_func(hint_prompt) # 记录对话历史 history = pack_user_ass_to_openai_messages(hint_prompt, final_result) # 多轮提取循环 for now_glean_index in range(entity_extract_max_gleaning): # 继续提取新的实体和关系 glean_result = await use_llm_func(continue_prompt, history_messages=history) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) final_result += glean_result if now_glean_index == entity_extract_max_gleaning - 1: break # 询问是否需要继续提取 if_loop_result: str = await use_llm_func( if_loop_prompt, history_messages=history ) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": break # 分割提取结果为独立记录 records = split_string_by_multi_markers( final_result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) # 用于收集可能的实体和关系 maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) # 处理每条记录 for record in records: # 提取括号内的属性内容 record = re.search(r"\((.*)\)", record) if record is None: continue record = record.group(1) # 分割属性列表 record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) # 尝试提取实体 if_entities = await _handle_single_entity_extraction( record_attributes, chunk_key ) if if_entities is not None: maybe_nodes[if_entities["entity_name"]].append(if_entities) continue # 尝试提取关系 if_relation = await _handle_single_relationship_extraction( record_attributes, chunk_key ) if if_relation is not None: maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( if_relation ) # 更新处理计数 already_processed += 1 already_entities += len(maybe_nodes) already_relations += len(maybe_edges) # 显示处理进度 now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) ] print( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", end="", flush=True, ) return dict(maybe_nodes), dict(maybe_edges) # 并发处理所有文本块 results = await asyncio.gather( *[_process_single_content(c) for c in ordered_chunks] ) print() # 清除进度条 # 合并所有提取结果 maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) for m_nodes, m_edges in results: for k, v in m_nodes.items(): maybe_nodes[k].extend(v) for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) # 合并并更新实体到知识图谱 all_entities_data = await asyncio.gather( *[ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) for k, v in maybe_nodes.items() ] ) # 合并并更新关系到知识图谱 all_relationships_data = await asyncio.gather( *[ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config) for k, v in maybe_edges.items() ] ) # 检查提取结果 if not len(all_entities_data): logger.warning("Didn't extract any entities, maybe your LLM is not working") return None if not len(all_relationships_data): logger.warning( "Didn't extract any relationships, maybe your LLM is not working" ) return None # 更新实体向量数据库 if entity_vdb is not None: data_for_vdb = { compute_mdhash_id(dp["entity_name"], prefix="ent-"): { "content": dp["entity_name"] + dp["description"], "entity_name": dp["entity_name"], } for dp in all_entities_data } await entity_vdb.upsert(data_for_vdb) # 更新关系向量数据库 if relationships_vdb is not None: data_for_vdb = { compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"], } for dp in all_relationships_data } await relationships_vdb.upsert(data_for_vdb) return knowledge_graph_inst async def local_query( query, # 用户的查询文本 knowledge_graph_inst: BaseGraphStorage, # 知识图谱存储实例 entities_vdb: BaseVectorStorage, # 实体的向量数据库 relationships_vdb: BaseVectorStorage, # 关系的向量数据库 text_chunks_db: BaseKVStorage[TextChunkSchema], # 原始文本块的存储 query_param: QueryParam, # 查询参数配置 global_config: dict, # 全局配置字典 ) -> str: """本地查询函数,从知识图谱中检索相关信息并生成回答 Args: query: 用户输入的查询文本 knowledge_graph_inst: 用于存储和查询知识图谱的实例 entities_vdb: 用于向量检索实体的数据库 relationships_vdb: 用于向量检索关系的数据库 text_chunks_db: 存储原始文本块的键值数据库 query_param: 包含查询相关参数的配置对象 global_config: 包含全局配置的字典 Returns: str: 查询的响应文本 """ context = None # 从全局配置中获取语言模型函数 use_model_func = global_config["llm_model_func"] # 构建关键词提取的提示 kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) # 尝试解析语言模型返回的关键词JSON try: keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: # 第一次解析失败,清理响应文本后重试 try: result = ( result.replace(kw_prompt[:-1], "") .replace("user", "") .replace("model", "") .strip() ) result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError as e: # 解析彻底失败,返回错误响应 print(f"JSON解析错误: {e}") return PROMPTS["fail_response"] # 如果成功提取到关键词,构建查询上下文 if keywords: context = await _build_local_query_context( keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) # 如果只需要上下文,直接返回 if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] # 构建RAG系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) response = await use_model_func( query, system_prompt=sys_prompt, ) # 清理语言模型的响应文本 if len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) return response async def _build_local_query_context( query, # 查询关键词 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 entities_vdb: BaseVectorStorage, # 实体向量数据库 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 query_param: QueryParam, # 查询参数 ): """构建本地查询的上下文信息 基于查询关键词,从知识图谱和向量数据库中检索相关的实体、关系和原文信息, 并将它们组织成结构化的上下文 """ # 从实体向量数据库中检索最相关的实体 results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return None # 并发获取检索到的实体节点数据 node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] ) if not all([n is not None for n in node_datas]): logger.warning("部分节点数据缺失,存储可能损坏") # 获取每个实体节点的度数(连接数) node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] ) # 合并节点信息,添加实体名称和排名 node_datas = [ {**n, "entity_name": k["entity_name"], "rank": d} for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] # 查找与实体最相关的文本单元和关系 use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst ) use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst ) # 记录使用的数据量 logger.info( f"本地查询使用了 {len(node_datas)} 个实体, {len(use_relations)} 个关系, {len(use_text_units)} 个文本单元" ) # 构建实体信息的CSV格式数据 entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(node_datas): entites_section_list.append( [ i, n["entity_name"], n.get("entity_type", "UNKNOWN"), n.get("description", "UNKNOWN"), n["rank"], ] ) entities_context = list_of_list_to_csv(entites_section_list) # 构建关系信息的CSV格式数据 relations_section_list = [ ["id", "source", "target", "description", "keywords", "weight", "rank"] ] for i, e in enumerate(use_relations): relations_section_list.append( [ i, e["src_tgt"][0], e["src_tgt"][1], e["description"], e["keywords"], e["weight"], e["rank"], ] ) relations_context = list_of_list_to_csv(relations_section_list) # 构建文本单元的CSV格式数据 text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) # 返回格式化的上下文信息 return f""" -----Entities----- ```csv {entities_context} ``` -----Relationships----- ```csv {relations_context} ``` -----Sources----- ```csv {text_units_context} ``` """ async def _find_most_related_text_unit_from_entities( node_datas: list[dict], # 实体节点数据列表 query_param: QueryParam, # 查询参数配置 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): """查找与给定实体最相关的文本单元 通过分析实体的一跳邻居和关系,找出最相关的原始文本单元 """ # 获取每个实体节点关联的文本单元ID text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in node_datas ] # 获取每个实体的边(关系)信息 edges = await asyncio.gather( *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] ) # 收集所有一跳邻居节点 all_one_hop_nodes = set() for this_edges in edges: if not this_edges: continue all_one_hop_nodes.update([e[1] for e in this_edges]) # 获取一跳邻居节点的数据 all_one_hop_nodes = list(all_one_hop_nodes) all_one_hop_nodes_data = await asyncio.gather( *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes] ) # 构建一跳邻居节点的文本单元查找表 all_one_hop_text_units_lookup = { k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP])) for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data) if v is not None and "source_id" in v # 检查节点数据有效性 } # 构建所有文本单元的查找表,包含顺序和关系计数信息 all_text_units_lookup = {} for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): for c_id in this_text_units: if c_id in all_text_units_lookup: continue # 计算文本单元与一跳邻居的关系数量 relation_counts = 0 if this_edges: # 检查边是否存在 for e in this_edges: if ( e[1] in all_one_hop_text_units_lookup and c_id in all_one_hop_text_units_lookup[e[1]] ): relation_counts += 1 # 获取文本单元数据 chunk_data = await text_chunks_db.get_by_id(c_id) if chunk_data is not None and "content" in chunk_data: # 检查内容有效性 all_text_units_lookup[c_id] = { "data": chunk_data, "order": index, "relation_counts": relation_counts, } # 过滤无效数据并确保包含内容 all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None and v.get("data") is not None and "content" in v["data"] ] if not all_text_units: logger.warning("No valid text units found") return [] # 按顺序和关系数量排序 all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) # 根据token数量限制截断文本单元列表 all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, ) # 只返回文本单元数据 all_text_units = [t["data"] for t in all_text_units] return all_text_units async def _find_most_related_edges_from_entities( node_datas: list[dict], # 实体节点数据列表 query_param: QueryParam, # 查询参数配置 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): """查找与给定实体最相关的边(关系) 获取实体的所有关系,并根据度数和权重进行排序 """ # 获取所有相关边 all_related_edges = await asyncio.gather( *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] ) # 收集所有唯一的边(确保边的方向一致) all_edges = set() for this_edges in all_related_edges: all_edges.update([tuple(sorted(e)) for e in this_edges]) all_edges = list(all_edges) # 获取边的详细信息和度数 all_edges_pack = await asyncio.gather( *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges] ) all_edges_degree = await asyncio.gather( *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] ) # 合并边的信息 all_edges_data = [ {"src_tgt": k, "rank": d, **v} for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree) if v is not None ] # 按度数和权重排序 all_edges_data = sorted( all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True ) # 根据token数量限制截断边列表 all_edges_data = truncate_list_by_token_size( all_edges_data, key=lambda x: x["description"], max_token_size=query_param.max_token_for_global_context, ) return all_edges_data async def global_query( query, # 用户查询文本 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 entities_vdb: BaseVectorStorage, # 实体向量数据库 relationships_vdb: BaseVectorStorage, # 关系向量数据库 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 query_param: QueryParam, # 查询参数配置 global_config: dict, # 全局配置 ) -> str: """执行全局查询,基于关系的向量检索 通过分析高层概念关键词,检索相关的关系和实体,构建查询上下文 """ context = None # 获取语言模型函数 use_model_func = global_config["llm_model_func"] # 提取查询关键词 kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) # 解析关键词JSON结果 try: keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: # 清理响应文本后重试解析 try: result = ( result.replace(kw_prompt[:-1], "") .replace("user", "") .replace("model", "") .strip() ) result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError as e: print(f"JSON解析错误: {e}") return PROMPTS["fail_response"] # 基于关键词构建查询上下文 if keywords: context = await _build_global_query_context( keywords, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, ) # 如果只需要上下文,直接返回 if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] # 构建系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) response = await use_model_func( query, system_prompt=sys_prompt, ) # 清理响应文本 if len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) return response async def _build_global_query_context( keywords, # 查询关键词 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 entities_vdb: BaseVectorStorage, # 实体向量数据库 relationships_vdb: BaseVectorStorage, # 关系向量数据库 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 query_param: QueryParam, # 查询参数 ): """构建全局查询的上下文信息 基于关系的向量检索,找出最相关的关系、实体和文本单元 """ # 从关系向量数据库检索相关关系 results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results): return None # 获取关系的详细数据 edge_datas = await asyncio.gather( *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] ) # 检查数据完整性 if not all([n is not None for n in edge_datas]): logger.warning("部分边数据缺失,存储可能损坏") # 获取关系的度数 edge_degree = await asyncio.gather( *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results] ) # 合并关系信息 edge_datas = [ {"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v} for k, v, d in zip(results, edge_datas, edge_degree) if v is not None ] # 按度数和权重排序 edge_datas = sorted( edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True ) # 根据token限制截断关系列表 edge_datas = truncate_list_by_token_size( edge_datas, key=lambda x: x["description"], max_token_size=query_param.max_token_for_global_context, ) # 获取相关的实体和文本单元 use_entities = await _find_most_related_entities_from_relationships( edge_datas, query_param, knowledge_graph_inst ) use_text_units = await _find_related_text_unit_from_relationships( edge_datas, query_param, text_chunks_db, knowledge_graph_inst ) # 记录使用的数据量 logger.info( f"全局查询使用了 {len(use_entities)} 个实体, {len(edge_datas)} 个关系, {len(use_text_units)} 个文本单元" ) # 构建关系的CSV格式数据 relations_section_list = [ ["id", "source", "target", "description", "keywords", "weight", "rank"] ] for i, e in enumerate(edge_datas): relations_section_list.append( [ i, e["src_id"], e["tgt_id"], e["description"], e["keywords"], e["weight"], e["rank"], ] ) relations_context = list_of_list_to_csv(relations_section_list) # 构建实体的CSV格式数据 entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(use_entities): entites_section_list.append( [ i, n["entity_name"], n.get("entity_type", "UNKNOWN"), n.get("description", "UNKNOWN"), n["rank"], ] ) entities_context = list_of_list_to_csv(entites_section_list) # 构建文本单元的CSV格式数据 text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) # 返回格式化的上下文信息 return f""" -----Entities----- ```csv {entities_context} ``` -----Relationships----- ```csv {relations_context} ``` -----Sources----- ```csv {text_units_context} ``` """ async def _find_most_related_entities_from_relationships( edge_datas: list[dict], # 关系数据列表 query_param: QueryParam, # 查询参数配置 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): """从关系中找出最相关的实体节点 通过分析关系的源节点和目标节点,获取所有相关实体的详细信息 """ # 收集所有关系中涉及的实体名称 entity_names = set() for e in edge_datas: entity_names.add(e["src_id"]) entity_names.add(e["tgt_id"]) # 并发获取所有实体节点的详细数据 node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] ) # 并发获取所有实体节点的度数(连接数) node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names] ) # 合并实体信息,添加实体名称和排名 node_datas = [ {**n, "entity_name": k, "rank": d} for k, n, d in zip(entity_names, node_datas, node_degrees) ] # 根据token数量限制截断实体列表 node_datas = truncate_list_by_token_size( node_datas, key=lambda x: x["description"], max_token_size=query_param.max_token_for_local_context, ) return node_datas async def _find_related_text_unit_from_relationships( edge_datas: list[dict], # 关系数据列表 query_param: QueryParam, # 查询参数配置 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 ): """从关系中找出相关的文本单元 通过分析关系的源文本,获取所有相关的原始文本单元 """ # 从每个关系中提取文本单元ID text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in edge_datas ] # 构建文本单元查找表,记录顺序信息 all_text_units_lookup = {} # 获取每个文本单元的详细数据 for index, unit_list in enumerate(text_units): for c_id in unit_list: if c_id not in all_text_units_lookup: all_text_units_lookup[c_id] = { "data": await text_chunks_db.get_by_id(c_id), "order": index, } # 检查数据完整性 if any([v is None for v in all_text_units_lookup.values()]): logger.warning("Text chunks are missing, maybe the storage is damaged") # 过滤无效数据并添加ID信息 all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None ] # 按出现顺序排序 all_text_units = sorted(all_text_units, key=lambda x: x["order"]) # 根据token数量限制截断文本单元列表 all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], max_token_size=query_param.max_token_for_text_unit, ) # 只返回文本单元数据 all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units] return all_text_units async def hybrid_query( query, # 用户查询文本 knowledge_graph_inst: BaseGraphStorage, # 知识图谱实例 entities_vdb: BaseVectorStorage, # 实体向量数据库 relationships_vdb: BaseVectorStorage, # 关系向量数据库 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 query_param: QueryParam, # 查询参数配置 global_config: dict, # 全局配置 ) -> str: """混合查询函数,结合局部和全局查询的优势 同时提取高层和低层关键词,分别构建上下文后合并 """ # 初始化上下文和模型函数 low_level_context = None high_level_context = None use_model_func = global_config["llm_model_func"] # 构建关键词提取提示 kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) # 获取语言模型响应 result = await use_model_func(kw_prompt) try: # 尝试解析关键词JSON keywords_data = json.loads(result) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) hl_keywords = ", ".join(hl_keywords) ll_keywords = ", ".join(ll_keywords) except json.JSONDecodeError: # 清理响应文本后重试解析 try: result = ( result.replace(kw_prompt[:-1], "") .replace("user", "") .replace("model", "") .strip() ) result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json.loads(result) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) hl_keywords = ", ".join(hl_keywords) ll_keywords = ", ".join(ll_keywords) except json.JSONDecodeError as e: print(f"JSON parsing error: {e}") return PROMPTS["fail_response"] # 如果有低层关键词,构建局部查询上下文 if ll_keywords: low_level_context = await _build_local_query_context( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) # 如果有高层关键词,构建全局查询上下文 if hl_keywords: high_level_context = await _build_global_query_context( hl_keywords, knowledge_graph_inst, entities_vdb, relationships_vdb, text_chunks_db, query_param, ) # 合并两种上下文 context = combine_contexts(high_level_context, low_level_context) # 如果只需要上下文,直接返回 if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] # 构建系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) response = await use_model_func( query, system_prompt=sys_prompt, ) # 清理响应文本 if len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) return response def combine_contexts(high_level_context, low_level_context): """合并高层和低层上下文 从两种上下文中提取并合并实体、关系和来源信息 """ # 定义从上下文字符串中提取各部分的内部函数 def extract_sections(context): # 使用正则表达式匹配实体部分 entities_match = re.search( r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL ) # 匹配关系部分 relationships_match = re.search( r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL ) # 匹配来源部分 sources_match = re.search( r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL ) # 提取匹配内容或返回空字符串 entities = entities_match.group(1) if entities_match else "" relationships = relationships_match.group(1) if relationships_match else "" sources = sources_match.group(1) if sources_match else "" return entities, relationships, sources # 从高层上下文提取内容 if high_level_context is None: warnings.warn( "High Level context is None. Return empty High entity/relationship/source" ) hl_entities, hl_relationships, hl_sources = "", "", "" else: hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context) # 从低层上下文提取内容 if low_level_context is None: warnings.warn( "Low Level context is None. Return empty Low entity/relationship/source" ) ll_entities, ll_relationships, ll_sources = "", "", "" else: ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) # 合并并去重实体信息 combined_entities = process_combine_contexts(hl_entities, ll_entities) # 合并并去重关系信息 combined_relationships = process_combine_contexts( hl_relationships, ll_relationships ) # 合并并去重来源信息 combined_sources = process_combine_contexts(hl_sources, ll_sources) # 返回格式化的合并上下文 return f""" -----Entities----- ```csv {combined_entities} ``` -----Relationships----- ```csv {combined_relationships} ``` -----Sources----- ```csv {combined_sources} ``` """ async def naive_query( query, # 用户查询文本 chunks_vdb: BaseVectorStorage, # 文本块向量数据库 text_chunks_db: BaseKVStorage[TextChunkSchema], # 文本块存储 query_param: QueryParam, # 查询参数配置 global_config: dict, # 全局配置 ): """简单查询函数,直接基于文本块的向量检索 不使用知识图谱,仅通过向量相似度检索相关文本块 """ # 获取语言模型函数 use_model_func = global_config["llm_model_func"] # 从向量数据库检索相关文本块 results = await chunks_vdb.query(query, top_k=query_param.top_k) if not len(results): return PROMPTS["fail_response"] # 获取检索结果的ID列表 chunks_ids = [r["id"] for r in results] # 从文本块存储获取完整内容 chunks = await text_chunks_db.get_by_ids(chunks_ids) # 根据token数量限制截断文本块列表 maybe_trun_chunks = truncate_list_by_token_size( chunks, key=lambda x: x["content"], max_token_size=query_param.max_token_for_text_unit, ) # 记录截断信息 logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") # 将文本块内容拼接成一个字符串 section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) # 如果只需要上下文,直接返回拼接的文本 if query_param.only_need_context: return section # 构建系统提示并获取语言模型响应 sys_prompt_temp = PROMPTS["naive_rag_response"] sys_prompt = sys_prompt_temp.format( content_data=section, response_type=query_param.response_type ) response = await use_model_func( query, system_prompt=sys_prompt, ) # 清理响应文本 if len(response) > len(sys_prompt): response = ( response[len(sys_prompt) :] .replace(sys_prompt, "") .replace("user", "") .replace("model", "") .replace(query, "") .replace("", "") .replace("", "") .strip() ) return response