LangGraph--设计一个给出标准提示词模板的聊天机器人

发布于:2025-06-16 ⋅ 阅读:(21) ⋅ 点赞:(0)

创建一个聊天机器人,帮助用户生成提示。它将首先收集用户的需求,然后生成提示(并根据用户输入进行优化)。这些功能被分为两个独立的状态,而 LLM 决定何时在这两个状态之间切换。

测试结果如下:

 

from typing import Annotated,List,Literal
from typing_extensions import TypedDict
from langchain_deepseek import ChatDeepSeek
from langchain_tavily import TavilySearch
from langchain_core.messages import BaseMessage,SystemMessage,AIMessage,HumanMessage,ToolMessage
from typing_extensions import TypedDict

from langgraph.graph import StateGraph,START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver

from langgraph.types import Command, interrupt
from pydantic import BaseModel

import os 
from dotenv import load_dotenv

# 加载.env文件中的环境变量
load_dotenv()


template = """你的任务是从用户那里获取他们想要创建哪种提示模板的信息。

你应该从他们那里获取以下信息:

- 提示的目标是什么
- 哪些变量将被传递到提示模板中
- 输出不应该做什么的任何限制
- 输出必须遵守的任何要求

如果你无法辨别这些信息,请要求他们澄清!不要试图胡乱猜测。

在你能辨别所有信息后,调用相关工具。"""

def get_messages_info(messages):
    return [SystemMessage(content=template)]+messages

class PromptInstructions(BaseModel):
    """关于如何提示LLM的说明"""
    objective:str
    variables:List[str]
    constraints:List[str]
    requirements:List[str]

llm = ChatDeepSeek(
    model="deepseek-chat",
    api_key=os.getenv("DEEPSEEK_API_KEY"))

llm_with_tool = llm.bind_tools([PromptInstructions])

def info_chain(state):
    messages = get_messages_info(state["messages"])
    response = llm_with_tool.invoke(messages)
    return {"messages": [response]}


# 现在设置将生成提示的状态。这需要单独的系统消息,以及一个函数来过滤掉所有工具调用之前的消息(因为那是前一个状态决定生成提示的时候)
prompt_system = """根据以下要求,编写一个良好的提示模板:

{reqs} """

# 获取提示消息的函数  
# 仅获取工具调用后的消息
def get_prompt_messages(messages:list):
    tool_call = None
    other_msgs=[]
    for msg in messages:
        if isinstance(msg, AIMessage) and msg.tool_calls:
            tool_call = msg.tool_calls[0]["args"]
        elif isinstance(msg,ToolMessage):
            continue
        elif tool_call is not None:
            other_msgs.append(msg)
    return [SystemMessage(content=prompt_system.format(reqs=tool_call))]+other_msgs

def prompt_gen_chain(state):
    messages = get_prompt_messages(state["messages"])
    response = llm.invoke(messages)
    return {"messages": [response]}

# 定义状态逻辑
# 聊天机器人所处状态的逻辑。如果最后一条消息是工具调用,那么我们处于"提示创建者"( prompt )应该回应的状态。
# 否则,如果最后一条消息不是 HumanMessage,那么我们知道人类应该下一个回应,所以我们处于 END 状态。
# 如果最后一条消息是 HumanMessage,那么如果之前有工具调用,我们处于 prompt 状态。否则,我们处于"信息收集"( info )状态。
def get_state(state):
    messages = state["messages"]
    if(isinstance(messages[-1],AIMessage) and messages[-1].tool_calls):
        return "add_tool_messages"
    elif not isinstance(messages[-1],HumanMessage):
        return END
    return "info"

# 创建图
# 现在可以创建这个图了。我们将使用一个 SqliteSaver 来持久化对话历史。
class State(TypedDict):
    messages:Annotated[List,add_messages]

memory = MemorySaver()
workflow = StateGraph(State)
workflow.add_node("info",info_chain)
workflow.add_node("prompt",prompt_gen_chain)

@workflow.add_node
def add_tool_message(state:State):
    return {
        "messages":[
            ToolMessage(
                content="Prompt generated!",
                tool_call_id = state["messages"][-1].too_calls[0]["id"],
                )       
            ]
    }

workflow.add_conditional_edges("info",get_state, ["add_tool_message","info",END])
workflow.add_edge("add_tool_message","prompt")
workflow.add_edge("prompt",END)
workflow.add_edge(START,"info")
graph = workflow.compile(checkpointer=memory)

graph_png = graph.get_graph().draw_mermaid_png()
with open("prompt_workflow.png", "wb") as f:
    f.write(graph_png)


# 使用图
# 现在我们可以使用创建的聊天机器人了。
import uuid

cached_human_responses = ["你好!", "rag 提示词", "1 rag, 2 none, 3 no, 4 no", "red", "q"]
cached_response_index = 0
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
while True:
    try:
        user = input("User (q/Q to quit): ")
    except:
        user = cached_human_responses[cached_response_index]
        cached_response_index += 1
    # print(f"User (q/Q to quit): {user}")
    if user in {"q", "Q"}:
        print("AI: Byebye")
        break
    output = None
    for output in graph.stream(
        {"messages": [HumanMessage(content=user)]}, config=config, stream_mode="updates"
    ):
        last_message = next(iter(output.values()))["messages"][-1]
        last_message.pretty_print()

    if output and "prompt" in output:
        print("Done!")


网站公告

今日签到

点亮在社区的每一天
去签到