构建智能 SQL 查询代理agent,把整个查询过程模块化,既能自动判断使用哪些表,又能自动生成 SQL 语句,最终返回查询结果

发布于:2025-03-03 ⋅ 阅读:(19) ⋅ 点赞:(0)

示例代码:

import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint
from pydantic import BaseModel
from langchain_core.tools import tool
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase

# 定义用于提取表类别的 Pydantic 模型
class Table(BaseModel):
    name: str

# 定义一个映射函数,将类别名称转换为具体的 SQL 表名列表
def get_tables(categories: List[Table]) -> List[str]:
    """根据类别名称映射到对应的 SQL 表名列表."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend([
                "Album",
                "Artist",
                "Genre",
                "MediaType",
                "Playlist",
                "PlaylistTrack",
                "Track",
            ])
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

class ChinookSQLAgent:
    """
    一个专门用于 Chinook SQL 数据库查询的 agent,
    利用 LLM 解析用户的问题,自动判断与问题相关的表类别,
    并生成相应的 SQL 查询执行。
    
    属性:
        sql_agent_llm: 用于解析问题和生成 SQL 查询的 LLM 模型。
        db: Chinook 数据库的连接对象。
        full_chain: 一个链条,将用户问题转为 SQL 查询后执行。
    
    构造方法参数:
        sqldb_directory (str): Chinook SQLite 数据库文件所在的目录路径。
        llm (str): LLM 模型名称(例如 "gpt-3.5-turbo"),但内部使用 "llama3-70b-8192"。
        llm_temperature (float): LLM 的温度参数,用于控制生成结果的随机性。
    """
    def __init__(self, sqldb_directory: str, llm: str, llm_temperature: float) -> None:
        # 初始化 LLM 模型(此处使用 "llama3-70b-8192",由 groq 提供)
        self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temperature)
        
        # 建立到 Chinook SQLite 数据库的连接
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        print("可用表:", self.db.get_usable_table_names())
        
        # 定义系统提示,指导 LLM 根据用户问题返回相关的表类别
        category_chain_system = (
            "Return the names of the SQL tables that are relevant to the user question. "
            "The tables are:\n\nMusic\nBusiness"
        )
        # 创建提取链,从用户问题中提取表类别(使用 Pydantic 模型 Table)
        category_chain = create_extraction_chain_pydantic(Table, self.sql_agent_llm, system_message=category_chain_system)
        # 将提取到的类别转换为具体的 SQL 表名
        table_chain = category_chain | get_tables
        
        # 定义自定义 SQL 提示模板
        custom_prompt = PromptTemplate(
            input_variables=["dialect", "input", "table_info", "top_k"],
            template=(
                "You are a SQL expert using {dialect}.\n"
                "Given the following table schema:\n"
                "{table_info}\n"
                "Generate a syntactically correct SQL query to answer the question: \"{input}\".\n"
                "Don't limit the results to {top_k} rows.\n"
                "Ensure the query uses DISTINCT to avoid duplicate rows.\n"
                "Return only the SQL query without any additional commentary or Markdown formatting."
            )
        )
        # 利用自定义提示模板创建 SQL 查询链
        query_chain = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)
        
        # 将输入中的 "question" 键转换为 table_chain 所需的 "input" 键
        table_chain = {"input": itemgetter("question")} | table_chain
        
        # 利用 RunnablePassthrough.assign 将提取到的 table_names 注入上下文,然后通过管道传递给 SQL 查询链
        self.full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

    def run(self, query: str) -> str:
        """
        接收用户的查询,将问题转化为 SQL 查询语句,然后在 Chinook 数据库中执行。
        
        参数:
            query (str): 用户的自然语言查询,例如 "What are all the genres of Alanis Morisette songs? Do not repeat!"
        
        返回:
            str: 执行 SQL 查询后的结果
        """
        # 调用完整的链条生成 SQL 查询语句
        sql_query = self.full_chain.invoke({"question": query})
        # 执行生成的 SQL 查询并返回结果
        return self.db.run(sql_query)

# 使用 @tool 装饰器将查询功能暴露为一个工具
@tool
def query_chinook_sqldb(query: str) -> str:
    """
    查询 Chinook SQL 数据库的工具函数,输入为用户的查询语句。
    
    该函数会实例化一个 ChinookSQLAgent,然后调用其 run 方法来处理查询。
    """
    # 注意:sqldb_directory 应该是数据库文件所在路径,这里假设它是全局变量或配置项
    sqldb_directory = here("data/Chinook.db")
    agent = ChinookSQLAgent(
        sqldb_directory=sqldb_directory,  # 如 TOOLS_CFG.chinook_sqldb_directory
        llm="llama3-70b-8192",              # 如 TOOLS_CFG.chinook_sqlagent_llm
        llm_temperature=0
    )
    return agent.run(query)


query_chinook_sqldb('What are all the genres of Alanis Morisette songs')

输出结果:
在这里插入图片描述
下面我来用通俗易懂的话解释这段代码的整体作用和每个部分的含义,同时举例说明输入输出情况。


1. 引入依赖和定义工具

代码最开始导入了各种模块和工具,这里主要用到的是 LangChain 相关的模块,比如:

  • Pydantic:用于定义数据模型(这里定义了一个 Table 模型,它只有一个字段 name,用来存储表类别名称)。
  • langchain_core 和 langchain.chains:提供了链(Chain)相关的操作,让我们可以将多个步骤拼接起来,比如从问题中提取信息、生成 SQL 查询等。
  • pyprojroot.here:用于定位当前项目目录下的文件路径,方便定位数据库文件。

2. 定义数据模型和映射函数

数据模型 Table

class Table(BaseModel):
    name: str
  • 作用
    定义了一个 Pydantic 模型 Table,用来描述从用户问题中提取出来的表类别。这个模型只有一个字段 name,比如可能返回 "Music""Business"

映射函数 get_tables

def get_tables(categories: List[Table]) -> List[str]:
    """根据类别名称映射到对应的 SQL 表名列表."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend([
                "Album",
                "Artist",
                "Genre",
                "MediaType",
                "Playlist",
                "PlaylistTrack",
                "Track",
            ])
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables
  • 作用
    接收一个 Table 对象列表,然后根据类别名称返回实际在 Chinook 数据库中使用的表名列表。
  • 举例说明
    • 如果提取结果是 [Table(name="Music")],那么函数返回的表名列表就是
      ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]
    • 如果类别是 "Business",则返回商业相关的表名列表。

3. 定义 ChinookSQLAgent 类

这个类封装了整个从自然语言问题到生成并执行 SQL 查询的流程。

3.1 初始化方法 init

def __init__(self, sqldb_directory: str, llm: str, llm_temperature: float) -> None:
    # 1. 初始化 LLM 模型
    self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temperature)
    
    # 2. 连接到 Chinook 数据库(SQLite)
    self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
    print("可用表:", self.db.get_usable_table_names())
  • 作用
    • 利用 init_chat_model 初始化语言模型(例如这里传入 "llama3-70b-8192"),用来解析用户问题和生成 SQL 查询。
    • 通过 SQLDatabase.from_uri 连接到 Chinook 数据库,并打印出数据库中可用的表(例如:Album、Artist、Customer、Employee 等)。

3.2 创建提取表类别的链

# 定义系统提示,告诉 LLM 只返回 "Music" 或 "Business" 两类
category_chain_system = (
    "Return the names of the SQL tables that are relevant to the user question. "
    "The tables are:\n\nMusic\nBusiness"
)
# 创建提取链(利用 Pydantic 模型 Table),从用户问题中提取出相关的表类别
category_chain = create_extraction_chain_pydantic(Table, self.sql_agent_llm, system_message=category_chain_system)
# 将提取出的类别映射为具体的 SQL 表名
table_chain = category_chain | get_tables
  • 作用

    • 定义一个系统提示(system message),指导 LLM 只考虑 “Music” 和 “Business” 两个类别。
    • 通过 create_extraction_chain_pydantic 创建一个链,自动从用户问题中提取出一个或多个 Table 对象。
    • 利用管道操作符 | 把提取出的结果传递给 get_tables 函数,得到实际的 SQL 表名列表。
  • 举例

    • 用户问题“哪些表中存储了 Alanis Morisette 歌曲信息?”会被 LLM 分析后返回 [Table(name="Music")],进而映射为音乐相关的所有表名。

3.3 定义自定义 SQL 提示模板

custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template=(
        "You are a SQL expert using {dialect}.\n"
        "Given the following table schema:\n"
        "{table_info}\n"
        "Generate a syntactically correct SQL query to answer the question: \"{input}\".\n"
        "Don't limit the results to {top_k} rows.\n"
        "Ensure the query uses DISTINCT to avoid duplicate rows.\n"
        "Return only the SQL query without any additional commentary or Markdown formatting."
    )
)
  • 作用
    • 定义了一个提示模板,让 LLM 生成 SQL 查询时遵循固定的格式。
    • 模板中说明:
      • 你是一个使用特定 SQL 方言({dialect})的 SQL 专家。
      • 根据给定的数据库表结构({table_info})和用户问题({input}),生成一条正确的 SQL 查询。
      • 不要限制返回行数({top_k}仅作为参考),并且必须使用 DISTINCT 去除重复行。
    • 这样就能让自然语言生成的 SQL 语句在逻辑上避免重复数据的问题,而无需后期修改生成的 SQL。

3.4 创建 SQL 查询链和组合完整链条

# 利用自定义提示模板创建 SQL 查询链
query_chain = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)

# 将输入中的 "question" 键转换为 table_chain 需要的 "input" 键
table_chain = {"input": itemgetter("question")} | table_chain

# 利用 RunnablePassthrough.assign 将提取到的 table_names 注入上下文,接着传递给 SQL 查询链
self.full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain
  • 作用
    • 用刚才定义的 custom_prompt 和数据库信息,创建一个 SQL 查询链,用于将自然语言问题转换成 SQL 查询语句。
    • 为了保证输入格式一致,将用户输入中的 question 键转换为 input 键(因为之前的链条是根据 input 来工作的)。
    • 最后利用 RunnablePassthrough.assign 将前面提取到的表名列表(table_names_to_use)注入到上下文中,并与 SQL 查询链拼接起来,构成一个完整的处理链。这条链会先从问题中提取出使用哪些表,然后再生成 SQL 语句。

3.5 run 方法

def run(self, query: str) -> str:
    """
    接收用户的查询,将问题转化为 SQL 查询语句,然后在 Chinook 数据库中执行。
    """
    # 调用完整链条生成 SQL 查询语句
    sql_query = self.full_chain.invoke({"question": query})
    # 执行生成的 SQL 查询并返回结果
    return self.db.run(sql_query)
  • 作用
    • 接受用户传入的自然语言查询(例如“What are all the genres of Alanis Morisette songs”)。
    • 通过调用完整链条(self.full_chain)将该查询转成 SQL 查询语句。
    • 最后在数据库中执行该 SQL 查询,并返回结果。

4. 将 agent 以工具形式暴露

@tool
def query_chinook_sqldb(query: str) -> str:
    """
    查询 Chinook SQL 数据库的工具函数,输入为用户的查询语句。
    """
    # 定位数据库文件,使用 pyprojroot 的 here 函数查找路径
    sqldb_directory = here("data/Chinook.db")
    # 实例化一个 ChinookSQLAgent
    agent = ChinookSQLAgent(
        sqldb_directory=sqldb_directory,  
        llm="llama3-70b-8192",
        llm_temperature=0
    )
    # 调用 agent 的 run 方法执行查询,并返回结果
    return agent.run(query)
  • 作用
    • 使用 @tool 装饰器将函数暴露为一个工具(比如在其他系统中可以直接调用)。
    • 函数内部通过 here("data/Chinook.db") 定位数据库文件路径,然后创建一个 ChinookSQLAgent 实例。
    • 最后调用 agent.run 方法来处理用户的查询并返回最终的查询结果。

5. 举例说明

假设用户调用如下命令:

query_chinook_sqldb('What are all the genres of Alanis Morisette songs')

整个流程如下:

  1. 输入:用户输入问题“What are all the genres of Alanis Morisette songs”。

  2. 提取表类别

    • LLM 根据预定义的系统提示分析问题,判断这个问题涉及音乐数据,所以会提取出类别 "Music"
    • 映射函数 get_tables"Music" 映射为相关的表名列表:["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]
  3. 生成 SQL 查询

    • 根据数据库的 schema(表结构信息)和自定义的 SQL 提示模板,LLM 会生成一条 SQL 查询语句。
    • 提示模板中要求生成的查询语句必须使用 DISTINCT 去除重复记录,因此生成的 SQL 可能类似下面这样:
      SELECT DISTINCT Genre.Name
      FROM Track
      JOIN Genre ON Track.GenreId = Genre.GenreId
      JOIN Artist ON Track.ArtistId = Artist.ArtistId
      WHERE Artist.Name LIKE '%Alanis Morisette%'
      
      (实际生成的语句可能会根据数据库结构略有不同)
  4. 执行查询

    • 生成的 SQL 查询语句通过 self.db.run(sql_query) 在 Chinook 数据库中执行,并返回查询结果。
  5. 输出

    • 最终,用户获得查询结果,比如数据库中所有不重复的音乐流派名称。

总结

这段代码的整体流程就是:

  • 利用 LLM 根据自然语言问题判断需要查询的表类别,
  • 将类别映射成 Chinook 数据库中实际的表名,
  • 利用自定义提示模板(要求生成的查询中使用 DISTINCT 去重)生成 SQL 查询语句,
  • 执行 SQL 查询并返回结果。

这种设计把整个查询过程模块化,既能自动判断使用哪些表,又能自动生成 SQL 语句,适合构建智能 SQL 查询代理。