创建一个聊天机器人,帮助用户生成提示。它将首先收集用户的需求,然后生成提示(并根据用户输入进行优化)。这些功能被分为两个独立的状态,而 LLM 决定何时在这两个状态之间切换。
测试结果如下:
from typing import Annotated,List,Literal
from typing_extensions import TypedDict
from langchain_deepseek import ChatDeepSeek
from langchain_tavily import TavilySearch
from langchain_core.messages import BaseMessage,SystemMessage,AIMessage,HumanMessage,ToolMessage
from typing_extensions import TypedDict
from langgraph.graph import StateGraph,START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.types import Command, interrupt
from pydantic import BaseModel
import os
from dotenv import load_dotenv
# 加载.env文件中的环境变量
load_dotenv()
template = """你的任务是从用户那里获取他们想要创建哪种提示模板的信息。
你应该从他们那里获取以下信息:
- 提示的目标是什么
- 哪些变量将被传递到提示模板中
- 输出不应该做什么的任何限制
- 输出必须遵守的任何要求
如果你无法辨别这些信息,请要求他们澄清!不要试图胡乱猜测。
在你能辨别所有信息后,调用相关工具。"""
def get_messages_info(messages):
return [SystemMessage(content=template)]+messages
class PromptInstructions(BaseModel):
"""关于如何提示LLM的说明"""
objective:str
variables:List[str]
constraints:List[str]
requirements:List[str]
llm = ChatDeepSeek(
model="deepseek-chat",
api_key=os.getenv("DEEPSEEK_API_KEY"))
llm_with_tool = llm.bind_tools([PromptInstructions])
def info_chain(state):
messages = get_messages_info(state["messages"])
response = llm_with_tool.invoke(messages)
return {"messages": [response]}
# 现在设置将生成提示的状态。这需要单独的系统消息,以及一个函数来过滤掉所有工具调用之前的消息(因为那是前一个状态决定生成提示的时候)
prompt_system = """根据以下要求,编写一个良好的提示模板:
{reqs} """
# 获取提示消息的函数
# 仅获取工具调用后的消息
def get_prompt_messages(messages:list):
tool_call = None
other_msgs=[]
for msg in messages:
if isinstance(msg, AIMessage) and msg.tool_calls:
tool_call = msg.tool_calls[0]["args"]
elif isinstance(msg,ToolMessage):
continue
elif tool_call is not None:
other_msgs.append(msg)
return [SystemMessage(content=prompt_system.format(reqs=tool_call))]+other_msgs
def prompt_gen_chain(state):
messages = get_prompt_messages(state["messages"])
response = llm.invoke(messages)
return {"messages": [response]}
# 定义状态逻辑
# 聊天机器人所处状态的逻辑。如果最后一条消息是工具调用,那么我们处于"提示创建者"( prompt )应该回应的状态。
# 否则,如果最后一条消息不是 HumanMessage,那么我们知道人类应该下一个回应,所以我们处于 END 状态。
# 如果最后一条消息是 HumanMessage,那么如果之前有工具调用,我们处于 prompt 状态。否则,我们处于"信息收集"( info )状态。
def get_state(state):
messages = state["messages"]
if(isinstance(messages[-1],AIMessage) and messages[-1].tool_calls):
return "add_tool_messages"
elif not isinstance(messages[-1],HumanMessage):
return END
return "info"
# 创建图
# 现在可以创建这个图了。我们将使用一个 SqliteSaver 来持久化对话历史。
class State(TypedDict):
messages:Annotated[List,add_messages]
memory = MemorySaver()
workflow = StateGraph(State)
workflow.add_node("info",info_chain)
workflow.add_node("prompt",prompt_gen_chain)
@workflow.add_node
def add_tool_message(state:State):
return {
"messages":[
ToolMessage(
content="Prompt generated!",
tool_call_id = state["messages"][-1].too_calls[0]["id"],
)
]
}
workflow.add_conditional_edges("info",get_state, ["add_tool_message","info",END])
workflow.add_edge("add_tool_message","prompt")
workflow.add_edge("prompt",END)
workflow.add_edge(START,"info")
graph = workflow.compile(checkpointer=memory)
graph_png = graph.get_graph().draw_mermaid_png()
with open("prompt_workflow.png", "wb") as f:
f.write(graph_png)
# 使用图
# 现在我们可以使用创建的聊天机器人了。
import uuid
cached_human_responses = ["你好!", "rag 提示词", "1 rag, 2 none, 3 no, 4 no", "red", "q"]
cached_response_index = 0
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
while True:
try:
user = input("User (q/Q to quit): ")
except:
user = cached_human_responses[cached_response_index]
cached_response_index += 1
# print(f"User (q/Q to quit): {user}")
if user in {"q", "Q"}:
print("AI: Byebye")
break
output = None
for output in graph.stream(
{"messages": [HumanMessage(content=user)]}, config=config, stream_mode="updates"
):
last_message = next(iter(output.values()))["messages"][-1]
last_message.pretty_print()
if output and "prompt" in output:
print("Done!")