众力资讯网

GraphRAG进阶:基于Neo4j与LlamaIndex的DRIFT搜索实现详解

微软的GraphRAG算得上是最早一批成熟的GraphRAG系统,它把索引阶段(抽取实体、关系、构建层级社区并生成摘要)

微软的GraphRAG算得上是最早一批成熟的GraphRAG系统,它把索引阶段(抽取实体、关系、构建层级社区并生成摘要)和查询阶段的高级能力整合到了一起。这套方案的优势在于,可以借助预先计算好的实体、关系、社区摘要来回答那些宏观的、主题性的问题,这恰恰是传统RAG系统基于文档检索难以做到的。

本文的重点是DRIFT搜索:Dynamic Reasoning and Inference with Flexible Traversal,翻译过来就是"动态推理与灵活遍历"。这是一种相对较新的检索策略,兼具全局搜索和局部搜索的特点。

DRIFT的工作流程是这样的:先通过向量搜索建立一个宽泛的查询起点,再利用群信息把原始问题拆解成更细粒度的后续查询。然后动态地在知识图谱上游走,抓取实体、关系等局部细节。这种设计在计算效率和答案质量之间找到了一个不错的平衡点。

上图为使用 LlamaIndex 工作流和 Neo4j 实现的 DRIFT 搜索,核心流程分一下几步:

首先是HyDE生成,基于一份样例社区报告构造假设性答案,用来改善查询的向量表示。

接着社区搜索登场,通过向量相似度找出最相关的社区报告,给查询提供宏观上下文。系统会分析这些结果,输出一个初步的中间答案,同时生成一批后续查询用于深挖。

这些后续查询会在局部搜索阶段并行执行,从知识图谱里捞出文本块、实体、关系、以及更多社区报告。这个过程可以迭代多轮,每轮都可能产生新的后续查询。

最后是答案生成,把过程中积累的所有中间答案汇总起来,融合社区级别的宏观洞察和局部细节,生成最终响应。整体思路就是先铺开、再聚焦,层层递进。

本文用的是《爱丽丝梦游仙境》,刘易斯·卡罗尔的经典作品,这部小说角色众多、场景丰富、事件环环相扣,拿来演示GraphRAG的能力再合适不过。

数据导入

整个pipeline遵循标准的GraphRAG流程,分三个阶段:

class MSGraphRAGIngestion(Workflow):      @step      async def entity_extraction(self, ev: StartEvent) -> EntitySummarization:          chunks = splitter.split_text(ev.text)          await ms_graph.extract_nodes_and_rels(chunks, ev.allowed_entities)          return EntitySummarization()        @step      async def entity_summarization(          self, ev: EntitySummarization      ) -> CommunitySummarization:          await ms_graph.summarize_nodes_and_rels()          return CommunitySummarization()        @step      async def community_summarization(          self, ev: CommunitySummarization      ) -> CommunityEmbeddings:          await ms_graph.summarize_communities()          return CommunityEmbeddings()

先从文本块里抽取实体和关系,再给节点和关系生成摘要,最后构建层级社区并生成社区摘要。

摘要做完之后,要给社区和实体都生成向量嵌入,这样才能支持相似性检索。社区嵌入的代码长这样:

@step      async def community_embeddings(self, ev: CommunityEmbeddings) -> EntityEmbeddings:          # Fetch all communities from the graph database          communities = ms_graph.query(              """      MATCH (c:__Community__)      WHERE c.summary IS NOT NULL AND c.rating > $min_community_rating      RETURN coalesce(c.title, "") + " " + c.summary AS community_description, c.id AS community_id      """,              params={"min_community_rating": MIN_COMMUNITY_RATING},          )          if communities:              # Generate vector embeddings from community descriptions              response = await client.embeddings.create(                  input=[c["community_description"] for c in communities],                  model=TEXT_EMBEDDING_MODEL,              )              # Store embeddings in the graph and create vector index              embeds = [                  {                      "community_id": community["community_id"],                      "embedding": embedding.embedding,                  }                  for community, embedding in zip(communities, response.data)              ]              ms_graph.query(                  """UNWIND $data as row              MATCH (c:__Community__ {id: row.community_id})              CALL db.create.setNodeVectorProperty(c, 'embedding', row.embedding)""",                  params={"data": embeds},              )              ms_graph.query(                  "CREATE VECTOR INDEX community IF NOT EXISTS FOR (c:__Community__) ON c.embedding"              )          return EntityEmbeddings()

实体嵌入同理,这样DRIFT搜索需要的向量索引就都建好了。

DRIFT搜索

DRIFT的检索思路其实很符合简单:先看大图,再挖细节。它不会一上来就在文档或实体层面做精确匹配,而是先去查群的摘要,因为这些摘要是对知识图谱主要主题的高层次概括。

拿到相关的高层信息后,DRIFT会智能地派生出后续查询,去精确检索特定实体、关系、源文档。这种两阶段的做法其实很像人类查资料的习惯:先大致了解情况再针对性地追问细节。既有全局搜索的覆盖面,又有局部搜索的精准度,而且不用把所有社区报告或文档都过一遍,计算开销控制得不错。

下面拆解一下各个阶段的实现。

群搜索

DRIFT用了HyDE技术来提升向量检索的准确率。不是直接拿用户query做embedding,而是先让模型生成一个假设性的答案,再用这个答案去做相似性搜索。道理很简单:假设答案在语义上跟真实的摘要更接近。

@step   async def hyde_generation(self, ev: StartEvent) -> CommunitySearch:      # Fetch a random community report to use as a template for HyDE generation      random_community_report = driver.execute_query(          """      MATCH (c:__Community__)      WHERE c.summary IS NOT NULL      RETURN coalesce(c.title, "") + " " + c.summary AS community_description""",          result_transformer_=lambda r: r.data(),      )      # Generate a hypothetical answer to improve query representation      hyde = HYDE_PROMPT.format(          query=ev.query, template=random_community_report[0]["community_description"]      )      hyde_response = await client.responses.create(          model="gpt-5-mini",          input=[{"role": "user", "content": hyde}],          reasoning={"effort": "low"},      )      return CommunitySearch(query=ev.query, hyde_query=hyde_response.output_text)

拿到HyDE query之后,做embedding,然后通过向量相似度捞出top 5的报告。接着让LLM基于这些报告生成一个初步答案,同时识别出需要深挖的后续查询。将初步答案存起来然后进行后续查询全部并行分发到局部搜索阶段。

@step   async def community_search(self, ctx: Context, ev: CommunitySearch) -> LocalSearch:      # Create embedding from the HyDE-enhanced query      embedding_response = await client.embeddings.create(          input=ev.hyde_query, model=TEXT_EMBEDDING_MODEL      )      embedding = embedding_response.data[0].embedding            # Find top 5 most relevant community reports via vector similarity      community_reports = driver.execute_query(          """      CALL db.index.vector.queryNodes('community', 5, $embedding) YIELD node, score      RETURN 'community-' + node.id AS source_id, node.summary AS community_summary      """,          result_transformer_=lambda r: r.data(),          embedding=embedding,      )            # Generate initial answer and identify what additional info is needed      initial_prompt = DRIFT_PRIMER_PROMPT.format(          query=ev.query, community_reports=community_reports      )      initial_response = await client.responses.create(          model="gpt-5-mini",          input=[{"role": "user", "content": initial_prompt}],          reasoning={"effort": "low"},      )      response_json = json_repair.loads(initial_response.output_text)      print(f"Initial intermediate response: {response_json['intermediate_answer']}")            # Store the initial answer and prepare for parallel local searches      async with ctx.store.edit_state() as ctx_state:          ctx_state["intermediate_answers"] = [              {                  "intermediate_answer": response_json["intermediate_answer"],                  "score": response_json["score"],              }          ]          ctx_state["local_search_num"] = len(response_json["follow_up_queries"])            # Dispatch follow-up queries to run in parallel      for local_query in response_json["follow_up_queries"]:          ctx.send_event(LocalSearch(query=ev.query, local_query=local_query))      return None

这就是DRIFT的核心思路,先用HyDE增强的社区搜索铺开,再用后续查询往下钻。

局部搜索

局部搜索阶段把后续查询并行跑起来,深入到具体细节。每个查询通过实体向量检索拿到目标上下文,生成中间答案,可能还会产出更多后续查询。

@step(num_workers=5)   async def local_search(self, ev: LocalSearch) -> LocalSearchResults:      print(f"Running local query: {ev.local_query}")            # Create embedding for the local query      response = await client.embeddings.create(          input=ev.local_query, model=TEXT_EMBEDDING_MODEL      )      embedding = response.data[0].embedding            # Retrieve relevant entities and gather their associated context:      # - Text chunks where entities are mentioned      # - Community reports the entities belong to      # - Relationships between the retrieved entities      # - Entity descriptions      local_reports = driver.execute_query(          """   CALL db.index.vector.queryNodes('entity', 5, $embedding) YIELD node, score   WITH collect(node) AS nodes   WITH   collect {    UNWIND nodes as n    MATCH (n)<-[:MENTIONS]->(c:__Chunk__)    WITH c, count(distinct n) as freq    RETURN {chunkText: c.text, source_id: 'chunk-' + c.id}    ORDER BY freq DESC    LIMIT 3   } AS text_mapping,   collect {    UNWIND nodes as n    MATCH (n)-[:IN_COMMUNITY*]->(c:__Community__)    WHERE c.summary IS NOT NULL    WITH c, c.rating as rank    RETURN {summary: c.summary, source_id: 'community-' + c.id}    ORDER BY rank DESC    LIMIT 3   } AS report_mapping,   collect {    UNWIND nodes as n    MATCH (n)-[r:SUMMARIZED_RELATIONSHIP]-(m)    WHERE m IN nodes    RETURN {descriptionText: r.summary, source_id: 'relationship-' + n.name + '-' + m.name}   LIMIT 3   } as insideRels,   collect {    UNWIND nodes as n    RETURN {descriptionText: n.summary, source_id: 'node-' + n.name}   } as entities   RETURN {Chunks: text_mapping, Reports: report_mapping,     Relationships: insideRels,     Entities: entities} AS output   """,          result_transformer_=lambda r: r.data(),          embedding=embedding,      )            # Generate answer based on the retrieved context      local_prompt = DRIFT_LOCAL_SYSTEM_PROMPT.format(          response_type=DEFAULT_RESPONSE_TYPE,          context_data=local_reports,          global_query=ev.query,      )      local_response = await client.responses.create(          model="gpt-5-mini",          input=[{"role": "user", "content": local_prompt}],          reasoning={"effort": "low"},      )      response_json = json_repair.loads(local_response.output_text)            # Limit follow-up queries to prevent exponential growth      response_json["follow_up_queries"] = response_json["follow_up_queries"][:LOCAL_TOP_K]            return LocalSearchResults(results=response_json, query=ev.query)

下一步负责编排迭代深化的过程。用collect_events等所有并行搜索跑完,然后判断要不要继续往下挖。如果当前深度还没到上限(这里设的max depth=2),就把所有结果里的后续查询提取出来,存好中间答案分发下一轮并行搜索。

@step   async def local_search_results(      self, ctx: Context, ev: LocalSearchResults   ) -> LocalSearch | FinalAnswer:      local_search_num = await ctx.store.get("local_search_num")            # Wait for all parallel searches to complete      results = ctx.collect_events(ev, [LocalSearchResults] * local_search_num)      if results is None:          return None                intermediate_results = [          {              "intermediate_answer": event.results["response"],              "score": event.results["score"],          }          for event in results      ]      current_depth = await ctx.store.get("local_search_depth", default=1)      query = [ev.query for ev in results][0]        # Continue drilling down if we haven't reached max depth      if current_depth < MAX_LOCAL_SEARCH_DEPTH:          await ctx.store.set("local_search_depth", current_depth + 1)          follow_up_queries = [              query              for event in results              for query in event.results["follow_up_queries"]          ]                    # Store intermediate answers and dispatch next round of searches          async with ctx.store.edit_state() as ctx_state:              ctx_state["intermediate_answers"].extend(intermediate_results)              ctx_state["local_search_num"] = len(follow_up_queries)            for local_query in follow_up_queries:              ctx.send_event(LocalSearch(query=query, local_query=local_query))          return None      else:          return FinalAnswer(query=query)

这样就形成了一个迭代细化的循环,每一层都在前一层的基础上继续深挖。达到最大深度后,触发最终答案生成。

最终答案

最后一步把整个DRIFT搜索过程中积攒的所有中间答案汇总成一个完整的响应:这里包括社区搜索的初步答案,以及局部搜索各轮迭代产出的答案。

@step   async def final_answer_generation(self, ctx: Context, ev: FinalAnswer) -> StopEvent:      # Retrieve all intermediate answers collected throughout the search process      intermediate_answers = await ctx.store.get("intermediate_answers")            # Synthesize all findings into a comprehensive final response      answer_prompt = DRIFT_REDUCE_PROMPT.format(          response_type=DEFAULT_RESPONSE_TYPE,          context_data=intermediate_answers,          global_query=ev.query,      )      answer_response = await client.responses.create(          model="gpt-5-mini",          input=[              {"role": "developer", "content": answer_prompt},              {"role": "user", "content": ev.query},          ],          reasoning={"effort": "low"},      )        return StopEvent(result=answer_response.output_text)

总结

DRIFT搜索提供了一个挺有意思的思路,在全局搜索的广度和局部搜索的精度之间找到了平衡。从社区级上下文切入,通过迭代的后续查询逐层下探,既避免了遍历所有社区报告的计算负担,又保证了覆盖面。

这里还有改进空间,比如目前的实现对所有中间答案一视同仁,如果能根据置信度分数做个筛选,最终答案的质量应该会更好,噪声也能降下来。后续查询也可以先按相关性或信息增益排个序,优先追踪最有价值的线索。

另一个值得尝试的方向是加一个查询精炼步骤,用LLM分析所有生成的后续查询,把相似的归并起来避免重复搜索,过滤掉那些大概率没什么收获的查询。这样能大幅减少局部搜索的次数,同时不影响答案质量。

完整代码

https://avoid.overfit.cn/post/5eaca452dcc7422d8c2308586e7cfe56

有兴趣的可以自己跑跑看,或者在这个基础上做些改进。

作者:Tomaz Bratanic