基于ReAct范式的问答系统实现Demo
参考文档
ReAct论文解读:LLM ReAct范式,在大语言模型中结合推理和动作
说明
由于我最近在做一个基于图数据库的问答系统,所以样例以查询图数据库为背景。实现过程仅供参考,希望能给大家带来帮助。
源码
import os
import json
from typing import Generator, Optional, Dict, Any
from neo4j import GraphDatabase
from openai import OpenAI
from dotenv import load_dotenv
# Load variables from a local .env file into os.environ so the os.getenv("NEO4J_*")
# lookups below can find them (presumably also the credentials OpenAI() reads — confirm).
load_dotenv()
# ----------------------------
# Neo4j 工具类
# ----------------------------
class Neo4jSearchTool:
    """Run raw Cypher queries against Neo4j and return the rows as a JSON string."""

    def __init__(self, driver=None):
        # Reuse a caller-supplied neo4j Driver when given; otherwise build one from
        # the NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD environment variables.
        if driver is None:
            driver = GraphDatabase.driver(
                os.getenv("NEO4J_URI"),
                auth=(os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD")),
            )
        self.driver = driver

    @staticmethod
    def _serialize(rows) -> str:
        """Encode result rows as a JSON array string; "[]" when there are no rows.

        ``default=str`` stringifies values json cannot encode natively (for example
        the driver's temporal types) instead of failing the whole query.
        """
        return json.dumps(rows, ensure_ascii=False, default=str) if rows else "[]"

    def run(self, query: str) -> str:
        """Execute `query` and return its rows as a JSON array string.

        Any exception is reported as "ERROR: <message>" instead of being raised, so
        the text can be handed back to the LLM as an observation.
        """
        try:
            with self.driver.session() as session:
                # Result.data() converts Node/Relationship/Path values to plain
                # dicts/lists; dict(record) kept driver objects json cannot encode.
                rows = session.run(query).data()
            return self._serialize(rows)
        except Exception as e:
            return f"ERROR: {str(e)}"
class Neo4jSchemaTool:
    """Read the node/relationship schema of a Neo4j database for use in LLM prompts."""

    def __init__(self, driver):
        # A neo4j Driver supplied by the caller; this class never closes it.
        self.driver = driver

    @staticmethod
    def _collect_properties(records, type_key: str) -> dict:
        """Fold schema-procedure rows into ``{type_name: {property: "T1, T2"}}``.

        Args:
            records: iterable of rows exposing ``[type_key]``, ``["propertyName"]``
                and ``["propertyTypes"]`` (neo4j Records or plain dicts).
            type_key: column holding the type name, e.g. "nodeType" or "relType".

        Every type gets an entry, even when its row has no property name (such rows
        carry a null propertyName/propertyTypes), so property-less labels and
        relationship types still show up in the schema.
        """
        schema = {}
        for rec in records:
            # Names arrive quoted, e.g. ":`Person`"; strip the colon and backticks.
            props = schema.setdefault(rec[type_key].strip(":`"), {})
            prop = rec["propertyName"]
            if prop:
                props[prop] = ", ".join(rec["propertyTypes"] or []) or "Unknown"
        return schema

    def get_node_schema(self, session):
        """Return ``{label: {property: types}}`` for every node type in the database."""
        q = """
        CALL db.schema.nodeTypeProperties()
        YIELD nodeType, propertyName, propertyTypes
        RETURN nodeType, propertyName, propertyTypes
        """
        return self._collect_properties(session.run(q), "nodeType")

    # ----------------------------------------------------------------------
    def get_relationship_schema(self, session):
        """
        For each relType: collect property definitions + a sampled (srcLabel, tgtLabel).

        The endpoint pair is stored under the "_endpoints" key, taken from the
        first label of each end of one arbitrary relationship instance, or
        ["Unknown", "Unknown"] when no instance exists.
        """
        # 1) property map (includes relationship types that have no properties)
        q_props = """
        CALL db.schema.relTypeProperties()
        YIELD relType, propertyName, propertyTypes
        RETURN relType, propertyName, propertyTypes
        """
        rel_schema = self._collect_properties(session.run(q_props), "relType")
        # 2) sample endpoints for each relationship type
        for rtype, props in rel_schema.items():
            # A backtick inside a quoted Cypher identifier must be doubled.
            quoted = rtype.replace("`", "``")
            q_sample = f"""
            MATCH (s)-[r:`{quoted}`]->(t)
            WITH labels(s)[0] AS src, labels(t)[0] AS tgt
            RETURN src, tgt LIMIT 1
            """
            rec = session.run(q_sample).single()
            props["_endpoints"] = [rec["src"], rec["tgt"]] if rec else ["Unknown", "Unknown"]
        return rel_schema

    def get_schema(self) -> dict:
        """Extract all node labels, relationship types and their properties."""
        with self.driver.session() as session:
            labels = self.get_node_schema(session)
            rel_types = self.get_relationship_schema(session)
        return {
            "NodeTypes": labels,
            "RelationshipTypes": rel_types,
        }

    def format_schema_prompt(self) -> str:
        """Render the schema as a prompt section: a header plus one JSON dump per part."""
        schema = self.get_schema()
        prompt = "数据库包含以下结构:\n"
        # Node labels and their properties
        prompt += "## 节点类型\n"
        prompt += json.dumps(schema["NodeTypes"], ensure_ascii=False)
        # Relationship types, their properties and sampled endpoints
        prompt += "\n## 关系类型\n"
        prompt += json.dumps(schema["RelationshipTypes"], ensure_ascii=False)
        return prompt
class AnswerValidator:
    """Heuristics for deciding whether the ReAct loop already holds an answer."""

    @staticmethod
    def is_valid_answer(observation: str) -> bool:
        """Return True when a tool observation contains at least one real value.

        Expected input is the text produced by Neo4jSearchTool.run: either an
        "ERROR..." message or a JSON array of row objects. Only the first row is
        inspected; it is valid when some column is neither null nor "". Falsy but
        meaningful values such as 0 or False count as answers.
        """
        if observation.startswith("ERROR") or observation == "[]":
            return False
        try:
            data = json.loads(observation)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return False
        if not isinstance(data, list) or not data or not isinstance(data[0], dict):
            return False
        return any(v is not None and v != "" for v in data[0].values())

    @staticmethod
    def _is_true_verdict(text) -> bool:
        """Interpret a true/false reply: whichever of "true"/"false" appears first wins."""
        lowered = (text or "").lower()
        true_pos = lowered.find("true")
        false_pos = lowered.find("false")
        return true_pos >= 0 and (false_pos < 0 or true_pos < false_pos)

    @staticmethod
    def should_terminate(llm_response: str, client=None, model: str = "deepseek-chat") -> bool:
        """Ask the LLM whether `llm_response` already contains a final answer.

        Args:
            llm_response: content to judge; formatted into the prompt with str().
            client: optional OpenAI-compatible client; a new ``OpenAI()`` by default.
            model: chat model name used for the judgement.
        """
        prompt = f"""判断以下模型响应是否包含最终答案:
响应内容:{llm_response}
只需返回true或false:"""
        llm = client if client is not None else OpenAI()
        response = llm.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        # A plain substring test would read "false, it is not true" as true.
        return AnswerValidator._is_true_verdict(response.choices[0].message.content)
# ----------------------------
# ReAct 引擎
# ----------------------------
class ReActQASystem:
    """ReAct (reason + act) loop that answers questions by letting an LLM query Neo4j."""

    def __init__(self, max_steps: int = 5):
        self.llm = OpenAI()
        neo4j_driver = GraphDatabase.driver(
            os.getenv("NEO4J_URI"),
            auth=(os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD")),
        )
        # Only entries exposing run(query) can be invoked by the LLM; "get_schema"
        # is used once below to render the schema into the prompt.
        self.tools = {
            "neo4j_search": Neo4jSearchTool(),
            "get_schema": Neo4jSchemaTool(neo4j_driver),
        }
        self.schema_prompt = self.tools["get_schema"].format_schema_prompt()
        self.max_steps = max_steps

    def _build_prompt(self, query: str, scratchpad: str = "") -> str:
        """Render the prompt: schema, tools, question, history and both reply formats."""
        lines = [
            "你是一个审计专家,需要根据数据库结构编写准确的Cypher查询。",
            self.schema_prompt,
            "可用工具:",
            '- neo4j_search: 执行Cypher查询,输入应为JSON格式的{"query": "MATCH..."}',
            f"当前问题:{query}",
            "历史步骤:",
            scratchpad,
            "严格按以下两种格式之一响应。",
            "需要查询数据库时:",
            "Thought: 分析问题并确认需要查询的标签和关系",
            "Action:",
            "```json",
            '{"action": "工具名", "action_input": {...}}',
            "```",
            # Without this branch the model had no sanctioned way to finish.
            "已经能够回答问题时:",
            "Thought: 根据历史步骤中的查询结果进行总结",
            "Final Answer: 给用户的最终答案",
        ]
        return "\n".join(lines)

    @staticmethod
    def _extract_final_answer(text: str) -> Optional[str]:
        """Return the text after "Final Answer:", or None when the marker is absent."""
        marker = "Final Answer:"
        pos = text.find(marker)
        if pos < 0:
            return None
        return text[pos + len(marker):].strip()

    @staticmethod
    def _run_tool(tool, action_input) -> str:
        """Run `tool` on the Cypher text in action_input ({"query": ...} or a bare string)."""
        cypher = action_input.get("query") if isinstance(action_input, dict) else action_input
        if not isinstance(cypher, str) or not cypher.strip():
            return 'ERROR: action_input 必须包含非空的 "query" 字段'
        return tool.run(cypher)

    def execute(self, query: str) -> Generator[str, None, None]:
        """Run the ReAct loop for `query`.

        Yields "Observation: ..." after each tool call, "ERROR: ..." for unusable
        actions, and finally "Final Answer: ..." — or an error message when
        ``max_steps`` iterations pass without an answer.
        """
        scratchpad = ""
        for _ in range(self.max_steps):
            prompt = self._build_prompt(query, scratchpad)
            print(f"LLM prompt: {prompt}")
            response = self.llm.chat.completions.create(
                model="deepseek-chat",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            # content can be None (e.g. a refusal); treat that as empty text.
            content = response.choices[0].message.content or ""
            print(f"LLM Response: {content}")
            print("================================================")

            thought, action = self._parse_response(content)
            scratchpad += f"\n{content}\n"
            if not action:
                answer = self._extract_final_answer(content)
                if answer is None:
                    answer = thought or content.strip()
                yield f"Final Answer: {answer}"
                return

            tool_name = action.get("action")
            action_input = action.get("action_input")
            tool = self.tools.get(tool_name) if isinstance(tool_name, str) else None
            if callable(getattr(tool, "run", None)):
                # Tool calls are executed directly; judging a Cypher query with the
                # termination LLM would only waste a call.
                observation = f"Observation: {self._run_tool(tool, action_input)}"
            elif AnswerValidator.should_terminate(action_input):
                yield f"Final Answer: {action_input}"
                return
            else:
                observation = f"ERROR: 未知工具 {tool_name}"
            # Feed the outcome back so the next step can build on (or correct) it.
            scratchpad += observation + "\n"
            yield observation

        yield f"ERROR: 已达到最大步数 {self.max_steps},未得到最终答案"

    def _parse_response(self, text: str) -> tuple[str, Optional[Dict]]:
        """Split an LLM reply into (thought, action).

        thought: text after "Thought:" up to the next "Action:" / "Final Answer:"
            marker (or end of text); "" when there is no "Thought:" marker.
        action: the JSON object inside the first ```json fenced block; None when the
            block is missing, unterminated, invalid JSON, or not a JSON object.
        """
        thought = ""
        action = None

        # Check find() for -1 BEFORE adding the marker length, otherwise a missing
        # marker yields a bogus positive offset.
        thought_pos = text.find("Thought:")
        if thought_pos >= 0:
            start = thought_pos + len("Thought:")
            ends = [p for p in (text.find("Action:", start), text.find("Final Answer:", start)) if p >= 0]
            thought = text[start:min(ends) if ends else len(text)].strip()

        fence_pos = text.find("```json")
        if fence_pos >= 0:
            body_start = fence_pos + len("```json")
            body_end = text.find("```", body_start)
            if body_end >= 0:
                try:
                    parsed = json.loads(text[body_start:body_end].strip())
                except json.JSONDecodeError:
                    parsed = None
                if isinstance(parsed, dict):
                    action = parsed
        return thought, action
# ----------------------------
# 主程序
# ----------------------------
def main():
    """Interactive console loop: read questions, print each ReAct step; "quit" exits."""
    qa_system = ReActQASystem()
    print("审计问答系统已启动(输入quit退出)")
    while True:
        try:
            query = input("\n用户提问: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session instead of dumping a traceback.
            print()
            break
        if not query:
            # Don't spend LLM calls on an empty question.
            continue
        if query.lower() == "quit":
            break
        print("\n系统响应:")
        for response in qa_system.execute(query):
            print(response)


if __name__ == "__main__":
    main()
总结
欢迎大家留言讨论。