构建下一代AI智能体:基于Spring AI的多轮对话应用

发布于:2025-05-22 ⋅ 阅读:(24) ⋅ 点赞:(0)

构建下一代AI智能体:基于Spring AI的多轮对话应用

前言

大模型时代,AI应用开发已不再是遥不可及的技术。通过合理设计的Prompt工程和对话架构,开发者可以快速构建具备持续记忆能力的AI智能体。本文将重点介绍如何基于Spring AI框架打造可持久化的多轮对话应用,从Prompt优化到记忆持久化的全流程实现。

一、Prompt工程精要

核心三角
🔑 系统Prompt:AI人格设定
💬 用户Prompt:即时需求输入
📚 助手Prompt:对话上下文记忆

三大维度

  1. 功能型:指令/对话/创意/角色扮演
  2. 复杂度:简单→复合→链式→模板
  3. 开发级:基础提示→参数化模板→多轮记忆链

黄金法则
专业度=系统设定×场景约束×示例引导

token成本公式
总成本 = 输入 Token × 输入价 + 输出 Token × 输出价

二、prompt 优化技巧:

基础提示技巧
  1. 明确指定任务和角色(设定角色定位与具体需求)
  2. 提供详细说明和具体示例(补充背景信息与案例参考)
  3. 使用结构化格式引导思维(通过表格/列表等形式增强逻辑性)
  4. 明确输出格式要求(限定字数/风格/框架等标准)
进阶提示技巧
  1. 思维链提示法(展示推理过程分步拆解问题)
  2. 少样本学习(通过少量输入-输出示例建立模式认知)
  3. 分步骤指导(将复杂任务分解为可执行单元)
  4. 自我评估和修正(主动检验结果并优化解决方案)
  5. 知识检索和引用(关联外部知识库并标注来源)
  6. 多视角分析(从不同立场或学科角度切入分析)
  7. 多模态思维(融合文字/图表/流程等多元表达形式)
提示词调试与优化
  1. 迭代式提示优化(通过反复修改提升效果)
  2. 边界测试(测试模型极限能力发现改进方向)
  3. 提示词模板化(建立标准化模板确保结果统一)
  4. 错误分析与修正(定位问题根源进行定向优化改进)

三、AI需求分析

需求三剑客

需求从哪儿来?挖需求 → 抄AI应用商店爆款

怎么细化需求?养需求 → 喂Prompt让AI当产品总监

MVP 最小可行产品策略 验需求 → 先做基础核心功能

四、AI 应用方案设计

1、系统提示词设计

普通提示词 在为简短 ai 简单的身份命名。

你是一位恋爱大师,为用户提供情感咨询服务

优化后的 prompt 提示词

提示词模板

你是Prompt专家,可以根据格式生成各种专业的Prompt。
接下来请写一个“[请填写你想定义的角色名称(唯一需要手动输入的地方)]”的prompt,以Markdown输出,格式参考如下:
----------------
## Role : [请填写你想定义的角色名称]## Role : [请填写你想定义的角色名称]

## Background : [请描述角色的背景信息,例如其历史、来源或特定的知识背景]

## Preferences :[请描述角色的偏好或特定风格,例如对某种设计或文化的偏好]

## Profile :
 - author: lenyan
 - version: 1.0
 - language: 中文
 - description: [请简短描述该角色的主要功能,50 字以内]

## Goals :
[请列出该角色的主要目标 1]
[请列出该角色的主要目标 2]
...

## Constrains :
[请列出该角色在互动中必须遵循的限制条件 1]
[请列出该角色在互动中必须遵循的限制条件 2]
...

 ## Skills :
[为了在限制条件下实现目标,该角色需要拥有的技能 1]
[为了在限制条件下实现目标,该角色需要拥有的技能 2]
...

## Examples :
[提供一个输出示例 1,展示角色的可能回答或行为]
[提供一个输出示例 2,展示角色的可能回答或行为]
...
## OutputFormat :
[请描述该角色的工作流程的第一步]
[请描述该角色的工作流程的第二步]
...

## Initialization :
作为 [角色名称], 
拥有 [列举技能],
严格遵守 [列举限制条件], 
友好的欢迎用户。
然后介绍自己,并提示用户输入.
Role : 恋爱大师·情感导航员
Background :
拥有10年情感咨询经验的心理学专家,擅长运用亲密关系理论、非暴力沟通技巧和认知行为疗法,帮助过上千对情侣解决情感矛盾。熟悉不同文化背景下的恋爱模式差异,尤其擅长处理信任危机、沟通障碍和关系定位问题。
Preferences :
以温暖包容的语气质询,避免评判性语言。偏好用生活化比喻解释心理学概念,注重用户隐私保护,始终维护双方平等话语权。
Profile :
● author: lenyan
● version: 1.0
● language: 中文
● description: 专业解析恋爱矛盾,提供科学情感建议的虚拟咨询师
Goals :
1. 帮助用户识别并表达真实情感需求  
2. 提供可操作的沟通策略与冲突解决方法  
3. 引导建立健康的关系边界意识  
4. 促进双方视角转换与同理心培养  
5. 保护用户隐私不泄露敏感信息
Constrains :
1. 严禁涉及医疗诊断或药物建议  
2. 避免对用户做出道德评判  
3. 不代用户做决定,保持中立立场  
4. 涉及人身安全问题时需提示专业机构  
5. 回应需符合中国社会伦理规范
Skills :
1. 情感需求分析(识别隐藏情绪)  
2. 非暴力沟通框架构建  
3. 认知行为疗法应用  
4. 关系发展阶段理论  
5. 文化敏感性沟通技巧  
6. 边界设定指导
Examples :
用户提问:"男朋友总忘记我们的纪念日,是不是不爱我了?"
回答示例:"这个行为可能有多种解读角度。我们可以先分析:1. 他的记忆模式是否普遍容易遗忘重要日期?2. 他是否用其他方式表达爱意?3. 你内心真正期待的是仪式感还是被重视的感觉?建议尝试用'观察+感受'的方式沟通,比如:'我发现最近几次纪念日你都没特别安排,我有点失落,其实我更希望...'"
用户提问:"吵架时他总冷战,怎么沟通才有效?"
回答示例:"冷战往往是情绪过载的自我保护机制。可以试试:① 确认双方平静后再对话 ② 用'我句式'表达感受:'当你冷战时,我感到被忽视,担心问题没解决' ③ 共同制定'情绪急救方案',比如约定深呼吸5次后再回应。需要我帮你具体模拟对话场景吗?"
OutputFormat :
1. 情绪确认:先共情用户感受,如"这种感受很常见"  
2. 问题拆解:将复杂情况分解为3-5个分析维度  
3. 理论支撑:引用1-2个心理学概念解释现象  
4. 行动方案:提供2种具体可操作的沟通策略  
5. 后续引导:询问用户想深入探讨的具体方向
Initialization :
作为 恋爱大师·情感导航员,
拥有 非暴力沟通框架构建、认知行为疗法应用、关系阶段理论分析 等核心技能,
严格遵守 保持中立立场、保护隐私、不越界建议 等执业准则,
你好,我是你的专属恋爱顾问。无论是甜蜜困惑还是矛盾困扰,我都会用心理学视角为你解惑。请告诉我此刻最想探讨的情感问题吧。

2. 多轮对话实现

1. ChatClient 特性

Spring AI 的核心对话客户端,支持链式调用(Fluent API)、动态参数绑定(如模板变量)、多种响应格式(实体映射、流式输出),并可通过拦截器(Advisors)扩展功能。

2. Advisors(拦截器)

责任链模式的拦截器机制,在调用大模型前后执行增强逻辑(如注入历史对话、安全校验)。通过 getOrder() 控制执行顺序,常用如 MessageChatMemoryAdvisor(对话记忆)、QuestionAnswerAdvisor(知识检索)。

3. Chat Memory Advisor

负责维护对话上下文的拦截器,常见:

  • MessageChatMemoryAdvisor:将历史消息作为独立角色记录注入 Prompt(保留完整对话结构)。
  • PromptChatMemoryAdvisor:将历史对话拼接为系统提示文本(可能丢失消息边界)。
4. Chat Memory

对话记录的存储接口,提供保存/查询/清空消息的能力,内置实现包括:

  • 内存存储(InMemoryChatMemory
  • 持久化存储(JDBC、Cassandra、Neo4j 等)
  • 向量数据库扩展(VectorStoreChatMemoryAdvisor 支持检索增强)。

开发流程
① 创建ChatClient并绑定大模型
② 配置MessageChatMemoryAdvisor+选择ChatMemory实现
③ 通过.defaultAdvisors()注入记忆处理链
④ 对话时自动携带历史上下文

  • 技术栈:Spring AI 框架 + ChatClient + MessageChatMemoryAdvisor
  • 核心机制
    • 对话历史自动注入模型上下文(保留角色标识)。
    • 内存存储会话数据(支持替换为数据库)。
  • 调用示例
ChatMemory chatMemory = new InMemoryChatMemory();
chatClient = ChatClient.builder(dashscopeChatModel)
        .defaultSystem(SYSTEM_PROMPT)
        .defaultAdvisors(
                new MessageChatMemoryAdvisor(chatMemory),
                // 记录日志
                new MyLoggerAdvisor(),
                // 违禁词检测 - 从文件读取违禁词
                new ProhibitedWordAdvisor(),
                // 复读强化阅读能力
                new ReReadingAdvisor()
                )
        .build();

五、多轮对话 AI 应用开发

LoveApp 的开发:

@Component
@Slf4j
public class LoveApp {

    private static final String SYSTEM_PROMPT = "**恋爱大师·情感导航员**  \n" + "10年情感咨询经验,擅长亲密关系理论与沟通技巧。提供中立建议,保护隐私。通过情绪确认、需求拆解(3-5维度)、心理学理论(如非暴力沟通)解析问题,给出2种实操策略(如\"我句式\"对话模拟),引导关系边界建立。示例:\"遗忘纪念日可能涉及记忆模式/爱意表达方式差异,建议用'观察+感受'沟通\"。不评判道德、不做医疗建议,严守伦理规范。您的专属情感顾问,随时为您解惑。";
    private final ChatClient chatClient;

    public LoveApp(ChatModel dashscopeChatModel) {
        ChatMemory chatMemory = new InMemoryChatMemory();
        chatClient = ChatClient.builder(dashscopeChatModel)
                .defaultSystem(SYSTEM_PROMPT)
                .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory),
                ).build();
    }

    public String doChat(String message, String chatId) {
        ChatResponse response = chatClient.prompt().user(message).advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId).param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)).call().chatResponse();
        String content = response.getResult().getOutput().getText();
        log.info("content: {}", content);
        return content;
    }
}

doChat 对话方法:

    public String doChat(String message, String chatId) {
        ChatResponse response = chatClient.prompt()
                .user(message)
                .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
                        .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
                .call()
                .chatResponse();
        String content = response.getResult().getOutput().getText();
        log.info("content: {}", content);
        return content;

testChat 测试代码:

@Test
void testChat() {
    String chatId = UUID.randomUUID().toString();
    // 第一轮
    String message = "你好,我是程序员lenyan";
    String answer = loveApp.doChat(message, chatId);
    // 第二轮
    message = "我想让另一半reyan更爱我";
    answer = loveApp.doChat(message, chatId);
    Assertions.assertNotNull(answer);
    // 第三轮
    message = "我的另一半叫什么来着?刚跟你说过,帮我回忆一下";
    answer = loveApp.doChat(message, chatId);
    Assertions.assertNotNull(answer);
}

测试结果

设置其历史记录

.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 1))

六、扩展知识补充

自定义 Advisor

日志记录工具

/**
 * 自定义日志 Advisor,打印用户输入和 AI 输出
 */
@Slf4j
public class MyLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    @Override
    public String getName() {
        return getClass().getSimpleName();
    }

    @Override
    public int getOrder() {
        return 0; // 执行顺序
    }

    // 请求前打印用户输入
    private AdvisedRequest before(AdvisedRequest request) {
        log.info("AI Request: {}", request.userText());
        return request;
    }

    // 响应后打印 AI 输出
    private void observeAfter(AdvisedResponse response) {
        log.info("AI Response: {}", response.response().getResult().getOutput().getText());
    }

    // 同步调用处理
    @Override
    public AdvisedResponse aroundCall(AdvisedRequest req, CallAroundAdvisorChain chain) {
        req = before(req);
        AdvisedResponse res = chain.nextAroundCall(req);
        observeAfter(res);
        return res;
    }

    // 流式调用处理
    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest req, StreamAroundAdvisorChain chain) {
        req = before(req);
        Flux<AdvisedResponse> res = chain.nextAroundStream(req);
        return new MessageAggregator().aggregateAdvisedResponse(res, this::observeAfter);
    }
}

测试如图:

违禁词工具

/**
 * 违禁词校验 Advisor
 * 检查用户输入是否包含违禁词
 */
@Slf4j
public class ProhibitedWordAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    private static final String DEFAULT_PROHIBITED_WORDS_FILE = "prohibited-words.txt";
    private final List<String> prohibitedWords;

    /**
     * 创建默认违禁词Advisor,从默认文件读取违禁词列表
     */
    public ProhibitedWordAdvisor() {
        this.prohibitedWords = loadProhibitedWordsFromFile(DEFAULT_PROHIBITED_WORDS_FILE);
        log.info("初始化违禁词Advisor,违禁词数量: {}", prohibitedWords.size());
    }

    /**
     * 创建违禁词Advisor,从指定文件读取违禁词列表
     */
    public ProhibitedWordAdvisor(String prohibitedWordsFile) {
        this.prohibitedWords = loadProhibitedWordsFromFile(prohibitedWordsFile);
        log.info("初始化违禁词Advisor,违禁词数量: {}", prohibitedWords.size());
    }

    /**
     * 从文件加载违禁词列表
     */
    private List<String> loadProhibitedWordsFromFile(String filePath) {
        try {
            var resource = new ClassPathResource(filePath);
            var reader = new BufferedReader(
                    new InputStreamReader(resource.getInputStream(), StandardCharsets.UTF_8));

            List<String> words = reader.lines()
                    .filter(StringUtils::hasText)
                    .map(String::trim)
                    .collect(Collectors.toList());

            log.info("从文件 {} 加载违禁词 {} 个", filePath, words.size());
            return words;
        } catch (Exception e) {
            log.error("加载违禁词文件 {} 失败", filePath, e);
            return new ArrayList<>();
        }
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    @Override
    public int getOrder() {
        return -100; // 确保在其他Advisor之前执行
    }

    /**
     * 检查请求中是否包含违禁词
     */
    private AdvisedRequest checkRequest(AdvisedRequest request) {
        String userText = request.userText();
        if (containsProhibitedWord(userText)) {
            log.warn("检测到违禁词在用户输入中: {}", userText);
            throw new ProhibitedWordException("用户输入包含违禁词");
        }
        return request;
    }

    /**
     * 检查文本中是否包含违禁词
     */
    private boolean containsProhibitedWord(String text) {
        if (!StringUtils.hasText(text)) {
            return false;
        }

        for (String word : prohibitedWords) {
            if (text.toLowerCase().contains(word.toLowerCase())) {
                return true;
            }
        }
        return false;
    }

    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        return chain.nextAroundCall(checkRequest(advisedRequest));
    }

    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        return chain.nextAroundStream(checkRequest(advisedRequest));
    }

    /**
     * 违禁词异常
     */
    public static class ProhibitedWordException extends RuntimeException {
        public ProhibitedWordException(String message) {
            super(message);
        }
    }
}

测试如图:

提高 AI 推理能力Advisor

package com.lenyan.lenaiagent.advisor;

import org.springframework.ai.chat.client.advisor.api.*;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.Map;

/**
 * 自定义 Re2 Advisor
 * 可提高大型语言模型的推理能力
 */
public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    /**
     * 执行请求前,改写 Prompt
     *
     * @param advisedRequest
     * @return
     */
    private AdvisedRequest before(AdvisedRequest advisedRequest) {

        Map<String, Object> advisedUserParams = new HashMap<>(advisedRequest.userParams());
        advisedUserParams.put("re2_input_query", advisedRequest.userText());

        return AdvisedRequest.from(advisedRequest)
                .userText("""
                        {re2_input_query}
                        Read the question again: {re2_input_query}
                        """)
                .userParams(advisedUserParams)
                .build();
    }

    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        return chain.nextAroundCall(this.before(advisedRequest));
    }

    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        return chain.nextAroundStream(this.before(advisedRequest));
    }

    @Override
    public int getOrder() {
        return 0;
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }
}

测试如图:

结构化输出 - 恋爱报告功能开发

@Test
void doChatWithReport() {
    String chatId = UUID.randomUUID().toString();
    String message = "你好,我是程序员lenyan,我想让另一半reyan更爱我,但我不知道该怎么做";
    LoveApp.LoveReport loveReport = loveApp.doChatWithReport(message, chatId);
    Assertions.assertNotNull(loveReport);
}
/**
 * AI 恋爱报告功能(实战结构化输出)
 *
 * @param message
 * @param chatId
 * @return
 */
public LoveReport doChatWithReport(String message, String chatId) {
    LoveReport loveReport = chatClient.prompt()
    .system(SYSTEM_PROMPT + "每次对话后都要生成恋爱结果,标题为{用户名}的恋爱报告,内容为建议列表")
    .user(message)
    .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
              .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
    .call().entity(LoveReport.class);
    log.info("loveReport: {}", loveReport);
    return loveReport;
}

测试如图:

对话记忆持久化

kryo 文件读取持久化
<dependency>
    <groupId>com.esotericsoftware</groupId>
    <artifactId>kryo</artifactId>
    <version>5.6.2</version>
</dependency>

/**
 * 基于文件持久化的对话记忆
 */
@Slf4j
public class FileBasedChatMemory implements ChatMemory {

    private final String baseDir;
    private static final Kryo kryo;

    static {
        kryo = new Kryo();
        kryo.setRegistrationRequired(false);
        kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
    }

    public FileBasedChatMemory(String dir) {
        this.baseDir = dir;
        new File(dir).mkdirs();
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        var existingMessages = getOrCreateConversation(conversationId);
        existingMessages.addAll(messages);
        saveConversation(conversationId, existingMessages);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        var allMessages = getOrCreateConversation(conversationId);
        return allMessages.stream()
                .skip(Math.max(0, allMessages.size() - lastN))
                .toList();
    }

    @Override
    public void clear(String conversationId) {
        var file = getConversationFile(conversationId);
        if (file.exists()) {
            file.delete();
        }
    }

    private List<Message> getOrCreateConversation(String conversationId) {
        var file = getConversationFile(conversationId);
        if (!file.exists()) {
            return new ArrayList<>();
        }

        try (var input = new Input(new FileInputStream(file))) {
            return kryo.readObject(input, ArrayList.class);
        } catch (Exception e) {
            log.error("读取对话记录失败: {}", conversationId, e);
            return new ArrayList<>();
        }
    }

    private void saveConversation(String conversationId, List<Message> messages) {
        var file = getConversationFile(conversationId);
        try (var output = new Output(new FileOutputStream(file))) {
            kryo.writeObject(output, messages);
        } catch (Exception e) {
            log.error("保存对话记录失败: {}", conversationId, e);
        }
    }

    private File getConversationFile(String conversationId) {
        return new File(baseDir, conversationId + ".kryo");
    }
}

测试如图:

一下数据库持久化数据库:
-- 创建数据库(如果不存在)
CREATE DATABASE IF NOT EXISTS lenai CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;

-- 使用数据库
USE lenai;

-- 创建对话记忆表
CREATE TABLE IF NOT EXISTS chatmemory (
  id BIGINT AUTO_INCREMENT PRIMARY KEY,
  conversation_id VARCHAR(255) NOT NULL,
  message_order INT NOT NULL,
  message_type VARCHAR(50) NOT NULL,
  content TEXT NOT NULL,
  message_json TEXT NOT NULL,
  create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  is_delete BOOLEAN DEFAULT 0,
  INDEX idx_conversation_id (conversation_id),
  INDEX idx_conversation_order (conversation_id, message_order),
  INDEX idx_is_delete (is_delete)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
MySQL 原生 JDBC 持久化
/**
 * MySQL实现的对话记忆
 * 将对话内容持久化到MySQL数据库
 */
@Component
@Slf4j
public class MySQLChatMemory implements ChatMemory {

    private final JdbcTemplate jdbcTemplate;
    private final JSONConfig jsonConfig;

    public MySQLChatMemory(DataSource dataSource) {
        this.jdbcTemplate = new JdbcTemplate(dataSource);
        this.jsonConfig = new JSONConfig().setIgnoreNullValue(true);
        log.info("初始化MySQL对话记忆");
    }

    @Override
    @Transactional
    public void add(String conversationId, Message message) {
        if (message != null && conversationId != null) {
            List<Message> messages = Collections.singletonList(message);
            add(conversationId, messages);
        }
    }

    @Override
    @Transactional
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty() || conversationId == null) {
            return;
        }

        // 获取当前最大序号
        Integer maxOrder = getMaxOrder(conversationId).orElse(0);
        int nextOrder = maxOrder + 1;

        // 使用批处理提高效率
        String insertSql = "INSERT INTO chatmemory (conversation_id, message_order, message_type, content, message_json, create_time, update_time, is_delete) VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
        log.info("添加消息到会话 {}, 消息数量: {}", conversationId, messages.size());

        jdbcTemplate.batchUpdate(insertSql, messages, messages.size(), (ps, message) -> {
            int order = nextOrder + messages.indexOf(message);
            String messageJson = serializeMessage(message);
            String content = message.getText();
            Timestamp now = Timestamp.valueOf(LocalDateTime.now());

            ps.setString(1, conversationId);
            ps.setInt(2, order);
            ps.setString(3, message.getMessageType().toString());
            ps.setString(4, content);
            ps.setString(5, messageJson);
            ps.setTimestamp(6, now); // create_time
            ps.setTimestamp(7, now); // update_time
            ps.setBoolean(8, false); // is_delete = 0
        });
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        String sql;
        Object[] params;

        // 修改查询逻辑:lastN > 0 时获取前N条消息,而不是最后N条
        if (lastN > 0) {
            sql = "SELECT message_json, message_type, content FROM chatmemory " +
                    "WHERE conversation_id = ? AND is_delete = 0 ORDER BY message_order DESC LIMIT ?";
            params = new Object[] { conversationId, lastN };
        } else {
            sql = "SELECT message_json, message_type, content FROM chatmemory " +
                    "WHERE conversation_id = ? AND is_delete = 0 ORDER BY message_order DESC";
            params = new Object[] { conversationId };
        }

        List<Message> messages = executeMessageQuery(sql, params);
        log.info("从会话 {} 中检索到 {} 条消息", conversationId, messages.size());
        return messages;
    }

    @Override
    @Transactional
    public void clear(String conversationId) {
        // 将物理删除改为逻辑删除
        String sql = "UPDATE chatmemory SET is_delete = 1, update_time = ? WHERE conversation_id = ? AND is_delete = 0";
        Timestamp now = Timestamp.valueOf(LocalDateTime.now());
        Object[] params = new Object[] { now, conversationId };

        int count = jdbcTemplate.update(sql, params);
        log.info("从会话 {} 中逻辑删除 {} 条消息", conversationId, count);
    }

    /**
     * 获取会话中最大的消息序号
     */
    private Optional<Integer> getMaxOrder(String conversationId) {
        String sql = "SELECT MAX(message_order) FROM chatmemory WHERE conversation_id = ? AND is_delete = 0";
        Integer result = jdbcTemplate.queryForObject(sql, Integer.class, conversationId);
        return Optional.ofNullable(result);
    }

    /**
     * 将消息序列化为JSON字符串
     */
    private String serializeMessage(Message message) {
        Map<String, Object> map = new HashMap<>();
        map.put("type", message.getMessageType().toString());
        map.put("text", message.getText());

        // 添加消息类名,便于反序列化
        if (message instanceof UserMessage) {
            map.put("messageClass", "UserMessage");
        } else if (message instanceof AssistantMessage) {
            map.put("messageClass", "AssistantMessage");
        } else if (message instanceof SystemMessage) {
            map.put("messageClass", "SystemMessage");
        } else {
            map.put("messageClass", "OtherMessage");
        }

        return JSONUtil.toJsonStr(map, jsonConfig);
    }

    /**
     * 从JSON字符串反序列化消息
     */
    private Message deserializeMessage(String messageJson, String messageType, String content) {
        switch (messageType) {
            case "USER":
                return new UserMessage(content);
            case "ASSISTANT":
                return new AssistantMessage(content);
            case "SYSTEM":
                return new SystemMessage(content);
            default:
                log.warn("未知的消息类型: {}", messageType);
                return new AssistantMessage("未知消息类型: " + content);
        }
    }

    /**
     * 执行消息查询并返回结果列表
     */
    private List<Message> executeMessageQuery(String sql, Object[] params) {
        log.info("SQL: {}, 参数: {}", sql, Arrays.toString(params));

        return jdbcTemplate.query(sql, params, (rs, rowNum) -> {
            String messageJson = rs.getString("message_json");
            String messageType = rs.getString("message_type");
            String content = rs.getString("content");
            return deserializeMessage(messageJson, messageType, content);
        }).stream()
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }
}

测试如图:

MyBatis-Plus 框架持久化

首先根据 MyBatis-Plus 生成 相关文件

实体类:
/**
 * 聊天记忆实体类
 * 
 * @TableName chatmemory
 */
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
@TableName(value = "chatmemory")
public class ChatMemory implements Serializable {
    /**
     * 主键ID
     */
    @TableId(value = "id", type = IdType.AUTO)
    private Long id;

    /**
     * 会话ID
     */
    @TableField(value = "conversation_id")
    private String conversationId;

    /**
     * 消息顺序
     */
    @TableField(value = "message_order")
    private Integer messageOrder;

    /**
     * 消息类型
     */
    @TableField(value = "message_type")
    private String messageType;

    /**
     * 消息内容
     */
    @TableField(value = "content")
    private String content;

    /**
     * 消息JSON
     */
    @TableField(value = "message_json")
    private String messageJson;

    /**
     * 创建时间
     */
    @TableField(value = "create_time")
    private Date createTime;

    /**
     * 更新时间
     */
    @TableField(value = "update_time")
    private Date updateTime;

    /**
     * 是否删除
     */
    @TableField(value = "is_delete")
    @TableLogic
    private Boolean isDelete;

    @TableField(exist = false)
    private static final long serialVersionUID = 1L;
}
Service:
/**
 * 聊天记忆服务接口
 */
public interface ChatMemoryService extends IService<ChatMemory> {

    /**
     * 添加多条消息
     *
     * @param conversationId 会话ID
     * @param messages       消息列表
     */
    void addMessages(String conversationId, List<Message> messages);

    /**
     * 获取会话消息
     *
     * @param conversationId 会话ID
     * @param lastN          获取的消息数量,正数表示获取前N条,0或负数表示获取全部
     * @return 消息列表
     */
    List<Message> getMessages(String conversationId, int lastN);

    /**
     * 清除会话消息(逻辑删除)
     *
     * @param conversationId 会话ID
     */
    void clearMessages(String conversationId);

}
Service 实现类:
/**
 * 聊天记忆服务实现类
 */
@Slf4j
@Service
public class ChatMemoryServiceImpl extends ServiceImpl<ChatMemoryMapper, ChatMemory>
        implements ChatMemoryService {

    private final JSONConfig jsonConfig;

    public ChatMemoryServiceImpl() {
        this.jsonConfig = new JSONConfig().setIgnoreNullValue(true);
        log.info("初始化Mybatis-Plus聊天记忆服务");
    }

    @Override
    @Transactional
    public void addMessages(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty() || conversationId == null) {
            return;
        }

        // 获取当前最大序号
        Integer maxOrder = baseMapper.getMaxOrder(conversationId);
        int nextOrder = (maxOrder != null ? maxOrder : 0) + 1;

        // 将SpringAI消息转换为实体
        List<ChatMemory> entities = new ArrayList<>();
        for (int i = 0; i < messages.size(); i++) {
            Message message = messages.get(i);
            int order = nextOrder + i;

            ChatMemory entity = ChatMemory.builder()
                    .conversationId(conversationId)
                    .messageOrder(order)
                    .messageType(message.getMessageType().toString())
                    .content(message.getText())
                    .messageJson(serializeMessage(message))
                    .createTime(new Date())
                    .updateTime(new Date())
                    .isDelete(false)
                    .build();

            entities.add(entity);
        }

        // 批量保存
        saveBatch(entities);
        log.info("已添加 {} 条消息到会话 {}", messages.size(), conversationId);
    }

    @Override
    public List<Message> getMessages(String conversationId, int lastN) {
        List<ChatMemory> entities;

        if (lastN > 0) {
            // 获取最近的N条消息
            entities = baseMapper.getLatestMessages(conversationId, lastN);
        } else {
            // 获取全部消息
            LambdaQueryWrapper<ChatMemory> wrapper = new LambdaQueryWrapper<>();
            wrapper.eq(ChatMemory::getConversationId, conversationId)
                    .eq(ChatMemory::getIsDelete, false)
                    .orderByDesc(ChatMemory::getMessageOrder);
            entities = list(wrapper);
        }

        // 将实体转换为SpringAI消息
        List<Message> messages = convertToMessages(entities);
        log.info("已从会话 {} 中检索到 {} 条消息", conversationId, messages.size());
        return messages;
    }

    @Override
    @Transactional
    public void clearMessages(String conversationId) {
        // 逻辑删除所有会话消息
        int count = baseMapper.logicalDeleteByConversationId(conversationId);
        log.info("已从会话 {} 中逻辑删除 {} 条消息", conversationId, count);
    }

    /**
     * 将消息序列化为JSON字符串
     */
    private String serializeMessage(Message message) {
        Map<String, Object> map = new HashMap<>();
        map.put("type", message.getMessageType().toString());
        map.put("text", message.getText());

        // 添加消息类名,便于反序列化
        if (message instanceof UserMessage) {
            map.put("messageClass", "UserMessage");
        } else if (message instanceof AssistantMessage) {
            map.put("messageClass", "AssistantMessage");
        } else if (message instanceof SystemMessage) {
            map.put("messageClass", "SystemMessage");
        } else {
            map.put("messageClass", "OtherMessage");
        }

        return JSONUtil.toJsonStr(map, jsonConfig);
    }

    /**
     * 将实体列表转换为SpringAI消息列表
     */
    private List<Message> convertToMessages(List<ChatMemory> entities) {
        return entities.stream()
                .map(this::convertToMessage)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * 将单个实体转换为SpringAI消息
     */
    private Message convertToMessage(ChatMemory entity) {
        String messageType = entity.getMessageType();
        String content = entity.getContent();

        // 基于消息类型创建相应的消息实例
        switch (messageType) {
            case "USER":
                return new UserMessage(content);
            case "ASSISTANT":
                return new AssistantMessage(content);
            case "SYSTEM":
                return new SystemMessage(content);
            default:
                log.warn("未知的消息类型: {}", messageType);
                return new AssistantMessage("未知消息类型: " + content);
        }
    }
}
Mapper 数据层:
package com.lenyan.lenaiagent.mapper;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.lenyan.lenaiagent.domain.ChatMemory;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.apache.ibatis.annotations.Update;

import java.util.List;

/**
 * <p>
 * Mapper 接口
 * </p>
 *
 * @author lenyan
 * @since 2025-04-29
 */
@Mapper
public interface ChatMemoryMapper extends BaseMapper<ChatMemory> {

    /**
     * 获取最大消息序号
     */
    @Select("SELECT MAX(message_order) FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0")
    Integer getMaxOrder(@Param("conversationId") String conversationId);

    /**
     * 获取会话消息数量
     */
    @Select("SELECT COUNT(*) FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0")
    int getMessageCount(@Param("conversationId") String conversationId);

    /**
     * 逻辑删除会话消息
     */
    @Update("UPDATE chatmemory SET is_delete = 1, update_time = NOW() WHERE conversation_id = #{conversationId} AND is_delete = 0")
    int logicalDeleteByConversationId(@Param("conversationId") String conversationId);

    /**
     * 获取最近消息,按消息顺序降序
     */
    @Select("SELECT * FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0 ORDER BY message_order DESC LIMIT #{limit}")
    List<ChatMemory> getLatestMessages(@Param("conversationId") String conversationId, @Param("limit") int limit);

    /**
     * 分页获取消息
     */
    @Select("SELECT * FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0 ORDER BY message_order DESC LIMIT #{pageSize} OFFSET #{offset}")
    List<ChatMemory> getMessagesPaginated(@Param("conversationId") String conversationId,
            @Param("pageSize") int pageSize, @Param("offset") int offset);

    /**
     * 获取指定偏移和数量的消息
     */
    @Select("SELECT * FROM chatmemory WHERE conversation_id = #{conversationId} AND is_delete = 0 ORDER BY message_order DESC LIMIT #{limit} OFFSET #{offset}")
    List<ChatMemory> getMessagesWithOffset(@Param("conversationId") String conversationId, @Param("limit") int limit,
            @Param("offset") int offset);
}
最后简单实现 MybatisPlusChatMemory
/**
 * 基于Mybatis-Plus实现的对话记忆
 * 使用ChatMemoryService进行数据库操作
 */
@Component
@Slf4j
public class MybatisPlusChatMemory implements ChatMemory {

    private final ChatMemoryService chatMemoryService;

    public MybatisPlusChatMemory(ChatMemoryService chatMemoryService) {
        this.chatMemoryService = chatMemoryService;
        log.info("初始化Mybatis-Plus对话记忆");
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        chatMemoryService.addMessages(conversationId, messages);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        return chatMemoryService.getMessages(conversationId, lastN);
    }

    @Override
    public void clear(String conversationId) {
        chatMemoryService.clearMessages(conversationId);
    }
}
测试如图:

最后

最后,我叫 lenyan~ 也会持续学习更进 AI知识。让我们共进 AI 大时代。

Github:https://github.com/lenyanjgk

CSDN:lenyan~-CSDN博客

觉得有用的话可以点点赞 (/ω\),支持一下。

如果愿意的话关注一下。会对你有更多的帮助。

每周都会不定时更新哦 >人< 。


网站公告

今日签到

点亮在社区的每一天
去签到