Langchain入门:构建一个基于SQL数据的问答系统

发布于:2025-08-09 ⋅ 阅读:(16) ⋅ 点赞:(0)

架构

从高层次来看,这些系统的步骤是:

  • 将问题转换为DSL查询:模型将用户输入转换为SQL查询。
  • 执行SQL查询:执行查询。
  • 回答问题:模型使用查询结果响应用户输入。

在这里插入图片描述
我们将使用一个OpenAI模型和一个基于FAISS的向量存储。

下面的示例将使用与Chinook数据库的SQLite连接。请按照这些安装步骤在与此笔记本相同的目录中创建Chinook.db:

  • 将此文件保存为Chinook.sql
  • 运行sqlite3 Chinook.db
  • 运行.read Chinook.sql
  • 测试SELECT * FROM Artist LIMIT 10;

现在,Chinhook.db在我们的目录中,我们可以使用SQLAlchemy驱动的SQLDatabase类与之接口:

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

在这里插入图片描述

链(即LangChain 可运行组件的组合)支持步骤可预测的应用程序。我们可以创建一个简单的链,它接受一个问题并执行以下操作:

  • 将问题转换为 SQL 查询;
  • 执行查询;
  • 使用结果回答原始问题。

将问题转换为 SQL 查询

SQL 链或代理的第一步是获取用户输入并将其转换为 SQL 查询。LangChain 提供了一个内置链来实现这一点:create_sql_query_chain。

from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
    openai_api_base = "https://api.siliconflow.cn/v1/",
    openai_api_key = os.environ['siliconflow'],
    model_name = "Qwen/Qwen3-8B",  # 模型名称
)
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there?"})
response

在这里插入图片描述
可以打开LangSmith追踪具体实现流程

执行 SQL 查询

我们可以使用 QuerySQLDatabaseTool 来轻松地将查询执行添加到我们的链中:

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there?"}) 

现在我们有了一种自动生成和执行查询的方法,我们只需将原始问题和 SQL 查询结果结合起来生成最终答案。我们可以通过将问题和结果再次传递给大型语言模型来实现:

这里笔者的结果格式化有异常,多了“SQLQuery:”,因此额外做了sql语句的多余格式化清理

from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

def clean_sql_query(query_result):
    """清理 SQL 查询,移除格式化文本"""
    if isinstance(query_result, str):
        # 移除 "SQLQuery:" 前缀
        if query_result.startswith("SQLQuery:"):
            return query_result.replace("SQLQuery:", "").strip()
        return query_result.strip()
    return query_result

chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | RunnableLambda(clean_sql_query) | execute_query
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)

chain.invoke({"question": "How many employees are there?"})

在这里插入图片描述
让我们回顾一下上面的 LCEL 中发生了什么。假设调用了这个链。

  • 在第一次 RunnablePassthrough.assign 之后,我们有一个包含两个元素的可运行对象: {“question”: question, “query”: write_query.invoke(question)} 在这里,write_query 将生成一个 SQL 查询,以回答问题。
  • 在第二个 RunnablePassthrough.assign 之后,我们添加了第三个元素 “result”,其中包含 execute_query.invoke(query),而 query 是在上一步计算得出的。
  • 这三个输入被格式化为提示词并传递给大型语言模型(LLM)。
  • StrOutputParser() 提取输出消息的字符串内容。

代理

LangChain 有一个 SQL 代理,它提供了一种比链更灵活的与 SQL 数据库交互的方式。使用 SQL 代理的主要优点包括:

  • 它可以根据数据库的架构以及数据库的内容(例如描述特定表)回答问题。
  • 它可以通过运行生成的查询来恢复错误,捕获回溯并正确重新生成。
  • 它可以根据需要多次查询数据库以回答用户问题。
  • 它将通过仅从相关表中检索架构来节省令牌。

要初始化代理,我们将使用 SQLDatabaseToolkit 创建一组工具:

  • 创建和执行查询
  • 检查查询语法
  • 检索表描述
  • … 以及更多
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

tools

在这里插入图片描述

from langchain_core.messages import SystemMessage
SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables.""" 

system_message = SystemMessage(content=SQL_PREFIX)

我们将使用一个预构建的 LangGraph 代理来构建我们的代理

from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(llm, tools, messages_modifier=system_message)

messages_modifier在消息传递给 LLM 之前,对消息列表进行预处理

for s in agent_executor.stream(
    {"messages": [HumanMessage(content="Which country's customers spent the most?")]}
):
    print(s)
    print("------")

请注意,代理会执行多个查询,直到获得所需的信息
在这里插入图片描述

代理同样可以处理定性问题:

for s in agent_executor.stream(
    {"messages": [HumanMessage(content="Describe the playlisttrack table")]}
):
    print(s)
    print("-----")

在这里插入图片描述
为了过滤包含专有名词(如地址、歌曲名称或艺术家)的列,我们首先需要仔细检查拼写,以便正确过滤数据。

我们可以通过创建一个包含数据库中所有不同专有名词的向量存储来实现这一点。然后,每当用户在问题中包含专有名词时,代理可以查询该向量存储,以找到该词的正确拼写。通过这种方式,代理可以确保在构建目标查询之前理解用户所指的实体。

首先,我们需要每个实体的唯一值,为此我们定义一个函数,将结果解析为元素列表:

import ast
import re

def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))

artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]
  • ast.literal_eval() 将字符串转换为实际的 Python 对象
  • r"\b\d+\b" 匹配独立的数字,不匹配嵌入在文字中的数字

在这里插入图片描述
使用这个函数,我们可以创建一个检索器工具,代理可以根据需要执行。

from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings(
        openai_api_base="https://api.siliconflow.cn/v1/",
        openai_api_key=os.environ["siliconFlow"],
        model="Qwen/Qwen3-Embedding-8B")
)
retriever = vector_db.as_retriever(search_kwargs={"k" : 5})
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""

retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)
print(retriever_tool.invoke("Alice Chains"))

在这里插入图片描述
这样,如果代理确定需要根据艺术家写一个过滤器,例如"Alice Chains",它可以首先使用检索器工具观察列的相关值。

将这些结合起来:

system = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

You have access to the following tables: {table_names}

If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool!
Do not try to guess at the proper name - use this function to find similar ones.""".format(
    table_names=db.get_usable_table_names()
)

system_message = SystemMessage(content=system)

tools.append(retriever_tool)

agent = create_react_agent(llm, tools, messages_modifier=system_message)
for s in agent.stream(
    {"messages": [HumanMessage(content="How many albums does alis in chain have?")]}
):
    print(s)
    print("-----")

在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到