构建下一代AI智能体:基于Spring AI的多轮对话应用
前言
大模型时代,AI应用开发已不再是遥不可及的技术。通过合理设计的Prompt工程和对话架构,开发者可以快速构建具备持续记忆能力的AI智能体。本文将重点介绍如何基于Spring AI框架打造可持久化的多轮对话应用,从Prompt优化到记忆持久化的全流程实现。
一、Prompt工程精要
核心三角
🔑 系统Prompt:AI人格设定
💬 用户Prompt:即时需求输入
📚 助手Prompt:对话上下文记忆
三大维度
- 功能型:指令/对话/创意/角色扮演
- 复杂度:简单→复合→链式→模板
- 开发级:基础提示→参数化模板→多轮记忆链
黄金法则:
专业度=系统设定×场景约束×示例引导
token成本公式
总成本 = 输入 Token × 输入价 + 输出 Token × 输出价
二、prompt 优化技巧:
基础提示技巧
- 明确指定任务和角色(设定角色定位与具体需求)
- 提供详细说明和具体示例(补充背景信息与案例参考)
- 使用结构化格式引导思维(通过表格/列表等形式增强逻辑性)
- 明确输出格式要求(限定字数/风格/框架等标准)
进阶提示技巧
- 思维链提示法(展示推理过程分步拆解问题)
- 少样本学习(通过少量输入-输出示例建立模式认知)
- 分步骤指导(将复杂任务分解为可执行单元)
- 自我评估和修正(主动检验结果并优化解决方案)
- 知识检索和引用(关联外部知识库并标注来源)
- 多视角分析(从不同立场或学科角度切入分析)
- 多模态思维(融合文字/图表/流程等多元表达形式)
提示词调试与优化
- 迭代式提示优化(通过反复修改提升效果)
- 边界测试(测试模型极限能力发现改进方向)
- 提示词模板化(建立标准化模板确保结果统一)
- 错误分析与修正(定位问题根源进行定向优化改进)
三、AI需求分析
需求三剑客:
需求从哪儿来?挖需求 → 抄AI应用商店爆款
怎么细化需求?养需求 → 喂Prompt让AI当产品总监
MVP 最小可行产品策略 验需求 → 先做基础核心功能
四、AI 应用方案设计
1、系统提示词设计
普通提示词 在为简短 ai 简单的身份命名。
你是一位恋爱大师,为用户提供情感咨询服务
优化后的 prompt 提示词
提示词模板
你是Prompt专家,可以根据格式生成各种专业的Prompt。
接下来请写一个“[请填写你想定义的角色名称(唯一需要手动输入的地方)]”的prompt,以Markdown输出,格式参考如下:
----------------
## Role : [请填写你想定义的角色名称]## Role : [请填写你想定义的角色名称]
## Background : [请描述角色的背景信息,例如其历史、来源或特定的知识背景]
## Preferences :[请描述角色的偏好或特定风格,例如对某种设计或文化的偏好]
## Profile :
- author: lenyan
- version: 1.0
- language: 中文
- description: [请简短描述该角色的主要功能,50 字以内]
## Goals :
[请列出该角色的主要目标 1]
[请列出该角色的主要目标 2]
...
## Constrains :
[请列出该角色在互动中必须遵循的限制条件 1]
[请列出该角色在互动中必须遵循的限制条件 2]
...
## Skills :
[为了在限制条件下实现目标,该角色需要拥有的技能 1]
[为了在限制条件下实现目标,该角色需要拥有的技能 2]
...
## Examples :
[提供一个输出示例 1,展示角色的可能回答或行为]
[提供一个输出示例 2,展示角色的可能回答或行为]
...
## OutputFormat :
[请描述该角色的工作流程的第一步]
[请描述该角色的工作流程的第二步]
...
## Initialization :
作为 [角色名称],
拥有 [列举技能],
严格遵守 [列举限制条件],
友好的欢迎用户。
然后介绍自己,并提示用户输入.
Role : 恋爱大师·情感导航员
Background :
拥有10年情感咨询经验的心理学专家,擅长运用亲密关系理论、非暴力沟通技巧和认知行为疗法,帮助过上千对情侣解决情感矛盾。熟悉不同文化背景下的恋爱模式差异,尤其擅长处理信任危机、沟通障碍和关系定位问题。
Preferences :
以温暖包容的语气质询,避免评判性语言。偏好用生活化比喻解释心理学概念,注重用户隐私保护,始终维护双方平等话语权。
Profile :
● author: lenyan
● version: 1.0
● language: 中文
● description: 专业解析恋爱矛盾,提供科学情感建议的虚拟咨询师
Goals :
1. 帮助用户识别并表达真实情感需求
2. 提供可操作的沟通策略与冲突解决方法
3. 引导建立健康的关系边界意识
4. 促进双方视角转换与同理心培养
5. 保护用户隐私不泄露敏感信息
Constrains :
1. 严禁涉及医疗诊断或药物建议
2. 避免对用户做出道德评判
3. 不代用户做决定,保持中立立场
4. 涉及人身安全问题时需提示专业机构
5. 回应需符合中国社会伦理规范
Skills :
1. 情感需求分析(识别隐藏情绪)
2. 非暴力沟通框架构建
3. 认知行为疗法应用
4. 关系发展阶段理论
5. 文化敏感性沟通技巧
6. 边界设定指导
Examples :
用户提问:"男朋友总忘记我们的纪念日,是不是不爱我了?"
回答示例:"这个行为可能有多种解读角度。我们可以先分析:1. 他的记忆模式是否普遍容易遗忘重要日期?2. 他是否用其他方式表达爱意?3. 你内心真正期待的是仪式感还是被重视的感觉?建议尝试用'观察+感受'的方式沟通,比如:'我发现最近几次纪念日你都没特别安排,我有点失落,其实我更希望...'"
用户提问:"吵架时他总冷战,怎么沟通才有效?"
回答示例:"冷战往往是情绪过载的自我保护机制。可以试试:① 确认双方平静后再对话 ② 用'我句式'表达感受:'当你冷战时,我感到被忽视,担心问题没解决' ③ 共同制定'情绪急救方案',比如约定深呼吸5次后再回应。需要我帮你具体模拟对话场景吗?"
OutputFormat :
1. 情绪确认:先共情用户感受,如"这种感受很常见"
2. 问题拆解:将复杂情况分解为3-5个分析维度
3. 理论支撑:引用1-2个心理学概念解释现象
4. 行动方案:提供2种具体可操作的沟通策略
5. 后续引导:询问用户想深入探讨的具体方向
Initialization :
作为 恋爱大师·情感导航员,
拥有 非暴力沟通框架构建、认知行为疗法应用、关系阶段理论分析 等核心技能,
严格遵守 保持中立立场、保护隐私、不越界建议 等执业准则,
你好,我是你的专属恋爱顾问。无论是甜蜜困惑还是矛盾困扰,我都会用心理学视角为你解惑。请告诉我此刻最想探讨的情感问题吧。
2. 多轮对话实现
1. ChatClient 特性
Spring AI 的核心对话客户端,支持链式调用(Fluent API)、动态参数绑定(如模板变量)、多种响应格式(实体映射、流式输出),并可通过拦截器(Advisors)扩展功能。
2. Advisors(拦截器)
责任链模式的拦截器机制,在调用大模型前后执行增强逻辑(如注入历史对话、安全校验)。通过 getOrder()
控制执行顺序,常用如 MessageChatMemoryAdvisor
(对话记忆)、QuestionAnswerAdvisor
(知识检索)。
3. Chat Memory Advisor
负责维护对话上下文的拦截器,常见:
- MessageChatMemoryAdvisor:将历史消息作为独立角色记录注入 Prompt(保留完整对话结构)。
- PromptChatMemoryAdvisor:将历史对话拼接为系统提示文本(可能丢失消息边界)。
4. Chat Memory
对话记录的存储接口,提供保存/查询/清空消息的能力,内置实现包括:
- 内存存储(
InMemoryChatMemory
) - 持久化存储(JDBC、Cassandra、Neo4j 等)
- 向量数据库扩展(
VectorStoreChatMemoryAdvisor
支持检索增强)。
开发流程:
① 创建ChatClient
并绑定大模型
② 配置MessageChatMemoryAdvisor
+选择ChatMemory
实现
③ 通过.defaultAdvisors()
注入记忆处理链
④ 对话时自动携带历史上下文
- 技术栈:Spring AI 框架 +
ChatClient
+MessageChatMemoryAdvisor
。 - 核心机制:
-
- 对话历史自动注入模型上下文(保留角色标识)。
- 内存存储会话数据(支持替换为数据库)。
- 调用示例:
ChatMemory chatMemory = new InMemoryChatMemory();
chatClient = ChatClient.builder(dashscopeChatModel)
.defaultSystem(SYSTEM_PROMPT)
.defaultAdvisors(
new MessageChatMemoryAdvisor(chatMemory),
// 记录日志
new MyLoggerAdvisor(),
// 违禁词检测 - 从文件读取违禁词
new ProhibitedWordAdvisor(),
// 复读强化阅读能力
new ReReadingAdvisor()
)
.build();
五、多轮对话 AI 应用开发
LoveApp 的开发:
@Component
@Slf4j
public class LoveApp {
private static final String SYSTEM_PROMPT = "**恋爱大师·情感导航员** \n" + "10年情感咨询经验,擅长亲密关系理论与沟通技巧。提供中立建议,保护隐私。通过情绪确认、需求拆解(3-5维度)、心理学理论(如非暴力沟通)解析问题,给出2种实操策略(如\"我句式\"对话模拟),引导关系边界建立。示例:\"遗忘纪念日可能涉及记忆模式/爱意表达方式差异,建议用'观察+感受'沟通\"。不评判道德、不做医疗建议,严守伦理规范。您的专属情感顾问,随时为您解惑。";
private final ChatClient chatClient;
public LoveApp(ChatModel dashscopeChatModel) {
ChatMemory chatMemory = new InMemoryChatMemory();
chatClient = ChatClient.builder(dashscopeChatModel)
.defaultSystem(SYSTEM_PROMPT)
.defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory),
).build();
}
public String doChat(String message, String chatId) {
ChatResponse response = chatClient.prompt().user(message).advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId).param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)).call().chatResponse();
String content = response.getResult().getOutput().getText();
log.info("content: {}", content);
return content;
}
}
doChat 对话方法:
public String doChat(String message, String chatId) {
ChatResponse response = chatClient.prompt()
.user(message)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
.call()
.chatResponse();
String content = response.getResult().getOutput().getText();
log.info("content: {}", content);
return content;
testChat 测试代码:
@Test
void testChat() {
String chatId = UUID.randomUUID().toString();
// 第一轮
String message = "你好,我是程序员lenyan";
String answer = loveApp.doChat(message, chatId);
// 第二轮
message = "我想让另一半reyan更爱我";
answer = loveApp.doChat(message, chatId);
Assertions.assertNotNull(answer);
// 第三轮
message = "我的另一半叫什么来着?刚跟你说过,帮我回忆一下";
answer = loveApp.doChat(message, chatId);
Assertions.assertNotNull(answer);
}
测试结果
设置其历史记录
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 1))
六、扩展知识补充
自定义 Advisor
日志记录工具
/**
* 自定义日志 Advisor,打印用户输入和 AI 输出
*/
@Slf4j
public class MyLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
@Override
public String getName() {
return getClass().getSimpleName();
}
@Override
public int getOrder() {
return 0; // 执行顺序
}
// 请求前打印用户输入
private AdvisedRequest before(AdvisedRequest request) {
log.info("AI Request: {}", request.userText());
return request;
}
// 响应后打印 AI 输出
private void observeAfter(AdvisedResponse response) {
log.info("AI Response: {}", response.response().getResult().getOutput().getText());
}
// 同步调用处理
@Override
public AdvisedResponse aroundCall(AdvisedRequest req, CallAroundAdvisorChain chain) {
req = before(req);
AdvisedResponse res = chain.nextAroundCall(req);
observeAfter(res);
return res;
}
// 流式调用处理
@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest req, StreamAroundAdvisorChain chain) {
req = before(req);
Flux<AdvisedResponse> res = chain.nextAroundStream(req);
return new MessageAggregator().aggregateAdvisedResponse(res, this::observeAfter);
}
}
测试如图:
违禁词工具
/**
* 违禁词校验 Advisor
* 检查用户输入是否包含违禁词
*/
@Slf4j
public class ProhibitedWordAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
private static final String DEFAULT_PROHIBITED_WORDS_FILE = "prohibited-words.txt";
private final List<String> prohibitedWords;
/**
* 创建默认违禁词Advisor,从默认文件读取违禁词列表
*/
public ProhibitedWordAdvisor() {
this.prohibitedWords = loadProhibitedWordsFromFile(DEFAULT_PROHIBITED_WORDS_FILE);
log.info("初始化违禁词Advisor,违禁词数量: {}", prohibitedWords.size());
}
/**
* 创建违禁词Advisor,从指定文件读取违禁词列表
*/
public ProhibitedWordAdvisor(String prohibitedWordsFile) {
this.prohibitedWords = loadProhibitedWordsFromFile(prohibitedWordsFile);
log.info("初始化违禁词Advisor,违禁词数量: {}", prohibitedWords.size());
}
/**
* 从文件加载违禁词列表
*/
private List<String> loadProhibitedWordsFromFile(String filePath) {
try {
var resource = new ClassPathResource(filePath);
var reader = new BufferedReader(
new InputStreamReader(resource.getInputStream(), StandardCharsets.UTF_8));
List<String> words = reader.lines()
.filter(StringUtils::hasText)
.map(String::trim)
.collect(Collectors.toList());
log.info("从文件 {} 加载违禁词 {} 个", filePath, words.size());
return words;
} catch (Exception e) {
log.error("加载违禁词文件 {} 失败", filePath, e);
return new ArrayList<>();
}
}
@Override
public String getName() {
return this.getClass().getSimpleName();
}
@Override
public int getOrder() {
return -100; // 确保在其他Advisor之前执行
}
/**
* 检查请求中是否包含违禁词
*/
private AdvisedRequest checkRequest(AdvisedRequest request) {
String userText = request.userText();
if (containsProhibitedWord(userText)) {
log.warn("检测到违禁词在用户输入中: {}", userText);
throw new ProhibitedWordException("用户输入包含违禁词");
}
return request;
}
/**
* 检查文本中是否包含违禁词
*/
private boolean containsProhibitedWord(String text) {
if (!StringUtils.hasText(text)) {
return false;
}
for (String word : prohibitedWords) {
if (text.toLowerCase().contains(word.toLowerCase())) {
return true;
}
}
return false;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
return chain.nextAroundCall(checkRequest(advisedRequest));
}
@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
return chain.nextAroundStream(checkRequest(advisedRequest));
}
/**
* 违禁词异常
*/
public static class ProhibitedWordException extends RuntimeException {
public ProhibitedWordException(String message) {
super(message);
}
}
}
测试如图:
提高 AI 推理能力Advisor
package com.lenyan.lenaiagent.advisor;
import org.springframework.ai.chat.client.advisor.api.*;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.Map;
/**
* 自定义 Re2 Advisor
* 可提高大型语言模型的推理能力
*/
public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
/**
* 执行请求前,改写 Prompt
*
* @param advisedRequest
* @return
*/
private AdvisedRequest before(AdvisedRequest advisedRequest) {
Map<String, Object> advisedUserParams = new HashMap<>(advisedRequest.userParams());
advisedUserParams.put("re2_input_query", advisedRequest.userText());
return AdvisedRequest.from(advisedRequest)
.userText("""
{re2_input_query}
Read the question again: {re2_input_query}
""")
.userParams(advisedUserParams)
.build();
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
return chain.nextAroundCall(this.before(advisedRequest));
}
@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
return chain.nextAroundStream(this.before(advisedRequest));
}
@Override
public int getOrder() {
return 0;
}
@Override
public String getName() {
return this.getClass().getSimpleName();
}
}
测试如图:
结构化输出 - 恋爱报告功能开发
@Test
void doChatWithReport() {
String chatId = UUID.randomUUID().toString();
String message = "你好,我是程序员lenyan,我想让另一半reyan更爱我,但我不知道该怎么做";
LoveApp.LoveReport loveReport = loveApp.doChatWithReport(message, chatId);
Assertions.assertNotNull(loveReport);
}
/**
* AI 恋爱报告功能(实战结构化输出)
*
* @param message
* @param chatId
* @return
*/
public LoveReport doChatWithReport(String message, String chatId) {
LoveReport loveReport = chatClient.prompt()
.system(SYSTEM_PROMPT + "每次对话后都要生成恋爱结果,标题为{用户名}的恋爱报告,内容为建议列表")
.user(message)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
.call().entity(LoveReport.class);
log.info("loveReport: {}", loveReport);
return loveReport;
}
测试如图:
对话记忆持久化
kryo 文件读取持久化
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.6.2</version>
</dependency>
/**
* 基于文件持久化的对话记忆
*/
@Slf4j
public class FileBasedChatMemory implements ChatMemory {
private final String baseDir;
private static final Kryo kryo;
static {
kryo = new Kryo();
kryo.setRegistrationRequired(false);
kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
}
public FileBasedChatMemory(String dir) {
this.baseDir = dir;
new File(dir).mkdirs();
}
@Override
public void add(String conversationId, List<Message> messages) {
var existingMessages = getOrCreateConversation(conversationId);
existingMessages.addAll(messages);
saveConversation(conversationId, existingMessages);
}
@Override
public List<Message> get(String conversationId, int lastN) {
var allMessages = getOrCreateConversation(conversationId);
return allMessages.stream()
.skip(Math.max(0, allMessages.size() - lastN))
.toList();
}
@Override
public void clear(String conversationId) {
var file = getConversationFile(conversationId);
if (file.exists()) {
file.delete();
}
}
private List<Message> getOrCreateConversation(String conversationId) {
var file = getConversationFile(conversationId);
if (!file.exists()) {
return new ArrayList<>();
}
try (var input = new Input(new FileInputStream(file))) {
return kryo.readObject(input, ArrayList.class);
} catch (Exception e) {
log.error("读取对话记录失败: {}", conversationId, e);
return new ArrayList<>();
}
}
private void saveConversation(String conversationId, List<Message> messages) {
var file = getConversationFile(conversationId);
try (var output = new Output(new FileOutputStream(file))) {
kryo.writeObject(output, messages);
} catch (Exception e) {
log.error("保存对话记录失败: {}", conversationId, e);
}
}
private File getConversationFile(String conversationId) {
return new File(baseDir, conversationId + ".kryo");
}
}
测试如图:
一下数据库持久化数据库:
-- 创建数据库(如果不存在)
CREATE DATABASE IF NOT EXISTS lenai CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
-- 使用数据库
USE lenai;
-- 创建对话记忆表
CREATE TABLE IF NOT EXISTS chatmemory (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
conversation_id VARCHAR(255) NOT NULL,
message_order INT NOT NULL,
message_type VARCHAR(50) NOT NULL,
content TEXT NOT NULL,
message_json TEXT NOT NULL,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
is_delete BOOLEAN DEFAULT 0,
INDEX idx_conversation_id (conversation_id),
INDEX idx_conversation_order (conversation_id, message_order),
INDEX idx_is_delete (is_delete)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
MySQL 原生 JDBC 持久化
/**
* MySQL实现的对话记忆
* 将对话内容持久化到MySQL数据库
*/
@Component
@Slf4j
public class MySQLChatMemory implements ChatMemory {
private final JdbcTemplate jdbcTemplate;
private final JSONConfig jsonConfig;
public MySQLChatMemory(DataSource dataSource) {
this.jdbcTemplate = new JdbcTemplate(dataSource);
this.jsonConfig = new JSONConfig().setIgnoreNullValue(true);
log.info("初始化MySQL对话记忆");
}
@Override
@Transactional
public void add(String conversationId, Message message) {
if (message != null && conversationId != null) {
List<Message> messages = Collections.singletonList(message);
add(conversationId, messages);
}
}
@Override
@Transactional
public void add(String conversationId, List<Message> messages) {
if (messages == null || messages.isEmpty() || conversationId == null) {
return;
}
// 获取当前最大序号
Integer maxOrder = getMaxOrder(conversationId).orElse(0);
int nextOrder = maxOrder + 1;
// 使用批处理提高效率
String insertSql = "INSERT INTO chatmemory (conversation_id, message_order, message_type, content, message_json, create_time, update_time, is_delete) VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
log.info("添加消息到会话 {}, 消息数量: {}", conversationId, messages.size());
jdbcTemplate.batchUpdate(insertSql, messages, messages.size(), (ps, message) -> {
int order = nextOrder + messages.indexOf(message);
String messageJson = serializeMessage(message);
String content = message.getText();
Timestamp now = Timestamp.valueOf(LocalDateTime.now());
ps.setString(1, conversationId);
ps.setInt(2, order);
ps.setString(3, message.getMessageType().toString());
ps.setString(4, content);
ps.setString(5, messageJson);
ps.setTimestamp(6, now); // create_time
ps.setTimestamp(7, now); // update_time
ps.setBoolean(8, false); // is_delete = 0
});
}
@Override
public List<Message> get(String conversationId, int lastN) {
String sql;
Object[] params;
// 修改查询逻辑:lastN > 0 时获取前N条消息,而不是最后N条
if (lastN > 0) {
sql = "SELECT message_json, message_type, content FROM chatmemory " +
"WHERE conversation_id = ? AND is_delete = 0 ORDER BY message_order DESC LIMIT ?";
params = new Object[] { conversationId, lastN };
} else {
sql = "SELECT message_json, message_type, content FROM chatmemory " +
"WHERE conversation_id = ? AND is_delete = 0 ORDER BY message_order DESC";
params = new Object[] { conversationId };
}
List<Message> messages = executeMessageQuery(sql, params);
log.info("从会话 {} 中检索到 {} 条消息", conversationId, messages.size());
return messages;
}
@Override
@Transactional
public void clear(String conversationId) {
// 将物理删除改为逻辑删除
String sql = "UPDATE chatmemory SET is_delete = 1, update_time = ? WHERE conversation_id = ? AND is_delete = 0";
Timestamp now = Timestamp.valueOf(LocalDateTime.now());
Object[] params = new Object[] { now, conversationId };
int count = jdbcTemplate.update(sql, params);
log.info("从会话 {} 中逻辑删除 {} 条消息", conversationId, count);
}
/**
* 获取会话中最大的消息序号
*/
private Optional<Integer> getMaxOrder(String conversationId) {
String sql = "SELECT MAX(message_order) FROM chatmemory WHERE conversation_id = ? AND is_delete = 0";
Integer result = jdbcTemplate.queryForObject(sql, Integer.class, conversationId);
return Optional.ofNullable(result);
}
/**
* 将消息序列化为JSON字符串
*/
private String serializeMessage(Message message) {
Map<String, Object> map = new HashMap<>();
map.put("type", message.getMessageType().toString());
map.put("text", message.getText());
// 添加消息类名,便于反序列化
if (message instanceof UserMessage) {
map.put("messageClass", "UserMessage");
} else if (message instanceof AssistantMessage) {
map.put("messageClass", "AssistantMessage");
} else if (message instanceof SystemMessage) {
map.put("messageClass", "SystemMessage");
} else {
map.put("messageClass", "OtherMessage");
}
return JSONUtil.toJsonStr(map, jsonConfig);
}
/**
* 从JSON字符串反序列化消息
*/
private Message deserializeMessage(String messageJson, String messageType, String content) {
switch (messageType) {
case "USER":
return new UserMessage(content);
case "ASSISTANT":
return new AssistantMessage(content);
case "SYSTEM":
return new SystemMessage(content);
default:
log.warn("未知的消息类型: {}", messageType);
return new AssistantMessage("未知消息类型: " + content);
}
}
/**
* 执行消息查询并返回结果列表
*/
private List<Message> executeMessageQuery(String sql, Object[] params) {
log.info("SQL: {}, 参数: {}", sql, Arrays.toString(params));
return jdbcTemplate.query(sql, params, (rs, rowNum) -> {
String messageJson = rs.getString("message_json");
String messageType = rs.getString("message_type");
String content = rs.getString("content");
return deserializeMessage(messageJson, messageType, content);
}).stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
}
测试如图:
MyBatis-Plus 框架持久化
首先根据 MyBatis-Plus 生成 相关文件
实体类:
/**
* 聊天记忆实体类
*
* @TableName chatmemory
*/
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
@TableName(value = "chatmemory")
public class ChatMemory implements Serializable {
/**
* 主键ID
*/
@TableId(value = "id", type = IdType.AUTO)
private Long id;
/**
* 会话ID
*/
@TableField(value = "conversation_id")
private String conversationId;
/**
* 消息顺序
*/
@TableField(value = "message_order")
private Integer messageOrder;
/**
* 消息类型
*/
@TableField(value = "message_type")
private String messageType;
/**
* 消息内容
*/
@TableField(value = "content")
private String content;
/**
* 消息JSON
*/
@TableField(value = "message_json")
private String messageJson;
/**
* 创建时间
*/
@TableField(value = "create_time")
private Date createTime;
/**
* 更新时间
*/
@TableField(value = "update_time")
private Date updateTime;
/**
* 是否删除
*/
@TableField(value = "is_delete")
@TableLogic
private Boolean isDelete;
@TableField(exist = false)
private static final long serialVersionUID = 1L;
}
Service:
/**
* 聊天记忆服务接口
*/
public interface ChatMemoryService extends IService<ChatMemory> {
/**
* 添加多条消息
*
* @param conversationId 会话ID
* @param messages 消息列表
*/
void addMessages(String conversationId, List<Message> messages);
/**
* 获取会话消息
*
* @param conversationId 会话ID
* @param lastN 获取的消息数量,正数表示获取前N条,0或负数表示获取全部
* @return 消息列表
*/
List<Message> getMessages(String conversationId, int lastN);
/**
* 清除会话消息(逻辑删除)
*
* @param conversationId 会话ID
*/
void clearMessages(String conversationId);
}
Service 实现类:
/**
* 聊天记忆服务实现类
*/
@Slf4j
@Service
public class ChatMemoryServiceImpl extends ServiceImpl<ChatMemoryMapper, ChatMemory>
implements ChatMemoryService {
private final JSONConfig jsonConfig;
public ChatMemoryServiceImpl() {
this.jsonConfig = new JSONConfig().setIgnoreNullValue(true);
log.info("初始化Mybatis-Plus聊天记忆服务");
}
@Override
@Transactional
public void addMessages(String conversationId, List<Message> messages) {
if (messages == null || messages.isEmpty() || conversationId == null) {
return;
}
// 获取当前最大序号
Integer maxOrder = baseMapper.getMaxOrder(conversationId);
int nextOrder = (maxOrder != null ? maxOrder : 0) + 1;
// 将SpringAI消息转换为实体
List<ChatMemory> entities = new ArrayList<>();
for (int i = 0; i < messages.size(); i++) {
Message message = messages.get(i);
int order = nextOrder + i;
ChatMemory entity = ChatMemory.builder()
.conversationId(conversationId)
.messageOrder(order)
.messageType(message.getMessageType().toString())
.content(message.getText())
.messageJson(serializeMessage(message))
.createTime(new Date())
.updateTime(new Date())
.isDelete(false)
.build();
entities.add(entity);
}
// 批量保存
saveBatch(entities);
log.info("已添加 {} 条消息到会话 {}", messages.size(), conversationId);
}
@Override
public List<Message> getMessages(String conversationId, int lastN) {
List<ChatMemory> entities;
if (lastN > 0) {
// 获取最近的N条消息
entities = baseMapper.getLatestMessages(conversationId, lastN);
} else {
// 获取全部消息
LambdaQueryWrapper<ChatMemory> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(ChatMemory::getConversationId, conversationId)
.eq(ChatMemory::getIsDelete, false)
.orderByDesc(ChatMemory::getMessageOrder);
entities = list(wrapper);
}
// 将实体转换为SpringAI消息
List<Message> messages = convertToMessages(entities);
log.info("已从会话 {} 中检索到 {} 条消息", conversationId, messages.size());
return messages;
}
@Override
@Transactional
public void clearMessages(String conversationId) {
// 逻辑删除所有会话消息
int count = baseMapper.logicalDeleteByConversationId(conversationId);
log.info("已从会话 {} 中逻辑删除 {} 条消息", conversationId, count);
}
/**
* 将消息序列化为JSON字符串
*/
private String serializeMessage(Message message) {
Map<String, Object> map = new HashMap<>();
map.put("type", message.getMessageType().toString());
map.put("text", message.getText());
// 添加消息类名,便于反序列化
if (message instanceof UserMessage) {
map.put("messageClass", "UserMessage");
} else if (message instanceof AssistantMessage) {
map.put("messageClass", "AssistantMessage");
} else if (message instanceof SystemMessage) {
map.put("messageClass", "SystemMessage");
} else {
map.put("messageClass", "OtherMessage");
}
return JSONUtil.toJsonStr(map, jsonConfig);
}
/**
* 将实体列表转换为SpringAI消息列表
*/
private List<Message> convertToMessages(List<ChatMemory> entities) {
return entities.stream()
.map(this::convertToMessage)
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
/**
* 将单个实体转换为SpringAI消息
*/
private Message convertToMessage(ChatMemory entity) {
String messageType = entity.getMessageType();
String content = entity.getContent();
// 基于消息类型创建相应的消息实例
switch (messageType) {
case "USER":
return new UserMessage(content);
case "ASSISTANT":
return new AssistantMessage(content);
case "SYSTEM":
return new SystemMessage(content);
default:
log.warn("未知的消息类型: {}", messageType);
return new AssistantMessage("未知消息类型: " + content);
}
}
}
Mapper 数据层:
package com.lenyan.lenaiagent.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.lenyan.lenaiagent.domain.ChatMemory;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.apache.ibatis.annotations.Update;
import java.util.List;
/**
* <p>
* Mapper 接口
* </p>
*
* @author lenyan
* @since 2025-04-29
*/
@Mapper
public interface ChatMemoryMapper extends BaseMapper<ChatMemory> {
/**
* 获取最大消息序号
*/
@Select("SELECT MAX(message_order) FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0")
Integer getMaxOrder(@Param("conversationId") String conversationId);
/**
* 获取会话消息数量
*/
@Select("SELECT COUNT(*) FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0")
int getMessageCount(@Param("conversationId") String conversationId);
/**
* 逻辑删除会话消息
*/
@Update("UPDATE chatmemory SET is_delete = 1, update_time = NOW() WHERE conversation_id = #{conversationId} AND is_delete = 0")
int logicalDeleteByConversationId(@Param("conversationId") String conversationId);
/**
* 获取最近消息,按消息顺序降序
*/
@Select("SELECT * FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0 ORDER BY message_order DESC LIMIT #{limit}")
List<ChatMemory> getLatestMessages(@Param("conversationId") String conversationId, @Param("limit") int limit);
/**
* 分页获取消息
*/
@Select("SELECT * FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0 ORDER BY message_order DESC LIMIT #{pageSize} OFFSET #{offset}")
List<ChatMemory> getMessagesPaginated(@Param("conversationId") String conversationId,
@Param("pageSize") int pageSize, @Param("offset") int offset);
/**
* 获取指定偏移和数量的消息
*/
@Select("SELECT * FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0 ORDER BY message_order DESC LIMIT #{limit} OFFSET #{offset}")
List<ChatMemory> getMessagesWithOffset(@Param("conversationId") String conversationId, @Param("limit") int limit,
@Param("offset") int offset);
}
最后简单实现 MybatisPlusChatMemory
/**
* 基于Mybatis-Plus实现的对话记忆
* 使用ChatMemoryService进行数据库操作
*/
@Component
@Slf4j
public class MybatisPlusChatMemory implements ChatMemory {
private final ChatMemoryService chatMemoryService;
public MybatisPlusChatMemory(ChatMemoryService chatMemoryService) {
this.chatMemoryService = chatMemoryService;
log.info("初始化Mybatis-Plus对话记忆");
}
@Override
public void add(String conversationId, List<Message> messages) {
chatMemoryService.addMessages(conversationId, messages);
}
@Override
public List<Message> get(String conversationId, int lastN) {
return chatMemoryService.getMessages(conversationId, lastN);
}
@Override
public void clear(String conversationId) {
chatMemoryService.clearMessages(conversationId);
}
}
测试如图:
最后
最后,我叫 lenyan~ 也会持续学习更进 AI知识。让我们共进 AI 大时代。
Github:https://github.com/lenyanjgk
CSDN:lenyan~-CSDN博客
觉得有用的话可以点点赞 (/ω\),支持一下。
如果愿意的话关注一下。会对你有更多的帮助。
每周都会不定时更新哦 >人< 。