Spring AI Source Code


Contents
Spring AI Introduction
Spring AI Components
Spring AI Structured Output
Spring AI Multimodality
Spring AI with Local Ollama
Spring AI Source Code
Spring AI Advisor Mechanism
Spring AI Tool Calling
Spring AI MCP
Spring AI RAG
Spring AI Agent

Spring AI is an application framework for AI engineering. Its goal is to apply Spring ecosystem design principles, such as portability and modular design, to the AI domain, and to promote the use of POJOs as the building blocks of AI applications.
At its core, Spring AI solves the fundamental challenge of AI integration: connecting your enterprise data and APIs with AI models.

AI Models


AI Embedding

Embeddings are numerical representations of text, images, or videos that capture the relationships between inputs.
Embeddings work by converting text, images, and video into arrays of floating-point numbers called vectors. These vectors are designed to capture the meaning of the text, images, and videos. The length of the embedding array is called the vector's dimensionality.

Structured Output


RAG


ETL


AI Tools


Code Packages

spring-ai-commons

spring-ai-commons is the foundational common module of the Spring AI framework, providing shared components and utility classes used across the other modules.
I. Module Role
1. Foundation layer

  • Serves as the low-level common package of Spring AI, containing shared implementations that do not fit into any single feature module
  • Provides standardized utility classes and interface definitions for the other modules (model interaction, vector stores, and so on)

2. Code-reuse hub

  • Centralizes frequently reused constants, exception handling, type conversion, and other basic logic

II. Core Features
1. Common utilities

  • String handling, JSON serialization, and other shared helper methods
  • Cross-module configuration parsing and validation tools

2. Base abstractions

  • Generic interface contracts for model loading, data conversion, and similar concerns
  • Standardized interaction protocols between modules (including MCP-related foundations)

3. Exception hierarchy

  • A uniform set of exception classes wrapping failures from AI model calls, data processing, and related scenarios

spring-ai-template-st

spring-ai-template-st is the string-template rendering component of the Spring AI framework, used mainly to generate the prompts required for interacting with AI models.
I. Core Features
1. Dynamic template rendering

  • Replaces template variables through placeholders (such as {input}) to produce structured prompts; a rendering sketch follows this section
  • Example template:
"Answer the question about {topic}. The answer must cover the following points: {requirements}"

2. Multi-role prompt support

  • Message templates can be defined per role (such as system and user), fitting multi-turn conversation scenarios

3. Deep integration with the Prompt class

  • Rendered output converts directly into a Prompt object for use in ChatClient calls

II. Technical Implementation

  • Underlying dependency: variable substitution is built on the StringTemplate (ST4) engine
  • Configuration: typically used together with the spring-ai-model module, with template instances injected via @Configuration
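
A rendering sketch, assuming the PromptTemplate API from Spring AI's prompt package (the topic values are illustrative):

PromptTemplate template = new PromptTemplate(
        "Answer the question about {topic}. The answer must cover: {requirements}");

Prompt prompt = template.create(Map.of(
        "topic", "Spring AOP",
        "requirements", "proxies, advice, pointcuts"));  // placeholders filled by name

The resulting Prompt can be handed straight to a ChatModel or ChatClient call.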

spring-ai-model

spring-ai-model is one of the core modules of the Spring AI framework, responsible for the shared interaction logic and abstract interfaces across all kinds of AI models.
I. Module Role
1. Model abstraction layer

  • Provides standardized calling interfaces across model providers (such as OpenAI, DeepSeek, and Hugging Face)
  • Defines a common interaction contract through the Model abstraction, hiding the low-level differences between vendor APIs (see the sketch after this list)

2. Multimodal support

  • Covers language models (such as ChatGPT), image generation (such as Stable Diffusion), embedding models, and other AI capabilities

II. Core Classes and Interfaces
1. Base abstractions

  • Model: the top-level abstract interface, defining a call() method that processes a ModelRequest and returns a ModelResponse
  • ModelRequest: wraps the input instructions and the model parameters (ChatOptions)
  • ModelResponse: holds the list of output results together with the response metadata (ResponseMetadata)

2. Derived model interfaces

  • ChatModel: the extension interface for conversational scenarios (implemented by, for example, ZhiPuAiChatModel)
  • EmbeddingModel: handles text and image vectorization tasks
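
The top-level contract itself is tiny; as shipped, the Model abstraction is essentially:

public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {

	TRes call(TReq request);

}

Every concrete capability (chat, embedding, image) specializes this request/response pair.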

spring-ai-vector-store

spring-ai-vector-store is the core module of the Spring AI framework dedicated to vector database interaction, providing standardized vector storage and retrieval capabilities.
I. Core Features
1. Unified interface abstraction

  • The VectorStore interface defines the standard operations: document writes (add), deletion (delete), and similarity search (similaritySearch)
  • Conditional data operations are supported through Filter.Expression

2. Multi-database adaptation

  • Built-in support for 20+ vector databases (such as Pinecone, PgVector, and Milvus), with a unified API hiding the low-level differences
  • SimpleVectorStore serves as a lightweight in-memory implementation, suitable for development and testing

3. RAG pipeline integration

  • Works together with the embedding model (EmbeddingModel) to convert documents into vectors automatically before storage
  • Supports context retrieval in retrieval-augmented generation (RAG) scenarios

II. Core Classes and Interfaces

1. Key interfaces

  • VectorStore: the top-level interface, extending DocumentWriter for document-writing capability (a usage sketch follows this section)
  • SearchRequest: wraps the parameters of a similarity search (such as result count and similarity threshold)

2. Example implementations

  • SimpleVectorStore: a concurrency-safe in-memory implementation that keeps vectors in a ConcurrentHashMap
  • PgVectorStore: integrates the PostgreSQL pgvector extension, with automatic table and index creation

3. Builder pattern

  • The AbstractVectorStoreBuilder base class enables chained configuration, with subclasses supplying the database-specific logic
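
A minimal usage sketch, assuming the 1.0-style SearchRequest builder (the query strings are illustrative):

// Write: documents are embedded via the configured EmbeddingModel, then stored
vectorStore.add(List.of(new Document("Spring AI supports 20+ vector databases")));

// Read: top-k similarity search with an optional score threshold
List<Document> results = vectorStore.similaritySearch(SearchRequest.builder()
    .query("Which vector databases does Spring AI support?")
    .topK(5)
    .similarityThreshold(0.7)
    .build());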

spring-ai-rag

spring-ai-rag is the core module of the Spring AI framework implementing retrieval-augmented generation (RAG). By combining vector retrieval with large-model generation, it addresses the static-knowledge limits and hallucination problems of plain large language models.
I. Core Architecture
1. Modular pipeline design

  • Pre-retrieval stage: supports query transformation and expansion (QueryTransformer implementations such as RewriteQueryTransformer)
  • Retrieval stage: similarity search through VectorStoreDocumentRetriever, with threshold filtering (similarityThreshold)
  • Post-retrieval stage: document re-ranking, compression, and context augmentation (ContextualQueryAugmenter)

2. Key components

  • QuestionAnswerAdvisor: an out-of-the-box wrapper around the RAG flow, with support for dynamic filter expressions
  • RetrievalAugmentationAdvisor (incubating): supports custom sequential RAG flows (see the sketch after this list)

II. Technical Characteristics
1. Dynamic knowledge integration

  • Extends the model's knowledge in real time through external knowledge bases (such as Milvus or Elasticsearch), with no retraining required

2. Hybrid retrieval support

  • Combines semantic search with traditional keyword retrieval to improve result relevance

3. Spring ecosystem integration

  • Works closely with spring-ai-vector-store, including auto-configured vector database connections
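
A wiring sketch for the incubating advisor (builder names per the 1.0 RAG API; the query is illustrative):

Advisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
    .documentRetriever(VectorStoreDocumentRetriever.builder()
        .vectorStore(vectorStore)
        .similarityThreshold(0.5)  // drop weak matches before augmentation
        .topK(3)
        .build())
    .build();

String answer = chatClient.prompt("Where can I buy tickets?")
    .advisors(ragAdvisor)
    .call()
    .content();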

spring-ai-advisors-vector-store

spring-ai-advisors-vector-store is the Spring AI module that combines vector stores with the advisor (interceptor) mechanism, aiming to simplify the implementation of retrieval-augmented generation (RAG) flows.
I. Module Role
1. RAG flow encapsulation

  • Provides the out-of-the-box QuestionAnswerAdvisor, which automatically injects vector-store retrieval results into the model request context to achieve retrieval-augmented generation
  • Supports dynamically extending the user query's context through RetrievalAugmentationAdvisor

2. Advisor architecture

  • Following the Spring AOP idea, a CallAroundAdvisor intercepts the model call and inserts vector-retrieval logic before and after the request
  • Both non-streaming (CallAroundAdvisorChain) and streaming (StreamAroundAdvisorChain) processing modes are supported

II. Core Components
1. Key classes and interfaces

QuestionAnswerAdvisor: the built-in RAG advisor, automatically wiring a VectorStore to the ChatModel
AdvisedRequest: wraps the original request and the shared context (such as the list of retrieved documents)
AdvisedResponse: holds the model response plus the augmented metadata

2. Configuration

// Register the default advisor at build time
ChatClient.builder(chatModel)
    .defaultAdvisors(
        new QuestionAnswerAdvisor(vectorStore)  // bind the vector store
    )
    .build();
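
The advisor can also be tuned per request; a sketch assuming the FILTER_EXPRESSION advisor parameter (the filter value is illustrative):

String answer = chatClient.prompt()
    .user("What does our refund policy say?")
    .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'policy'"))
    .call()
    .content();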

spring-ai-retry

spring-ai-retry is the Spring AI module dedicated to retrying AI model calls, strengthening fault tolerance through standardized policies.
I. Core Features
1. Exception classification

  • Defines two exception families, TransientAiException (retryable) and NonTransientAiException (non-retryable), drawing a clear retry boundary
  • Automatically recognizes transient failures such as network timeouts and rate limits

2. Policy configuration

  • Parameters such as the maximum attempt count (maxAttempts) and the backoff interval (backoff) can be customized
  • A built-in exponential backoff algorithm prevents retry storms

II. Key Components

1. Core classes

  • RetryUtils: static helpers for retry logic, supporting synchronous and asynchronous calls (see the sketch below)
  • RetryTemplate (extension point): integrates with the Spring Retry module for more complex policy combinations

2. Configuration example

# application.properties
spring.ai.retry.max-attempts=3
spring.ai.retry.backoff.initial-interval=1000ms
spring.ai.retry.backoff.multiplier=2.0
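
Programmatic use is a one-liner; a sketch using the preconfigured template from RetryUtils (which, among its defaults, retries only on TransientAiException):

// Wrap any model call in the shared retry policy
RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
ChatResponse response = retryTemplate.execute(ctx -> chatModel.call(prompt));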

spring-ai-client-chat

spring-ai-client-chat is the core module of Spring AI for interacting with large language models (LLMs), providing a standardized chat-style API and higher-level conveniences.
I. Core Components
1. The ChatClient interface
- Role: unifies chat interaction across models (such as OpenAI, Gemini, or a local LLM), with synchronous/streaming responses, context management, and structured output

  • Usage:
// Fluent request building
chatClient.prompt()
    .system("You are a Java expert")       // system prompt
    .user("Explain how Spring AOP works")  // user input
    .call()      // synchronous execution
    .content();  // extract the text response

Streaming (stream()), entity mapping (entity(Class)), and dynamic parameter injection (param()) are also supported.
2. The request builder

  • A fluent API for configuring model parameters such as temperature and the maximum token count (maxTokens)

II. Key Characteristics
1. Multi-model support

  • The ChatModel abstraction layer is compatible with 20+ model providers (such as Anthropic and ZhiPu AI); switching models only requires swapping the dependency

2. Context management

  • Built-in conversation memory (ChatMemory) automatically maintains multi-turn history
  • Continuous conversations can also pass history messages through a Prompt object

3. Structured output

  • JSON responses are mapped automatically onto Java objects:
record Joke(String setup, String punchline) {}
Joke joke = chatClient.prompt()
    .user("Tell me a joke")
    .call()
    .entity(Joke.class);  // automatic deserialization

This suits scenarios that need strongly typed responses.

spring-ai-mcp

spring-ai-mcp is the core module of Spring AI implementing the Model Context Protocol (MCP), providing standardized interfaces for dynamic interaction between large language models (LLMs) and external tools or data sources.
I. Protocol Role
1. Core goals

  • Standardizes the interaction protocol between LLMs and external resources (databases, APIs, files, and so on), much like an "HTTP protocol" for the AI domain
  • Supports dynamic tool discovery, context awareness, and secure resource access, solving the fragmentation problems of traditional function calling

2. Layered architecture

  • Transport layer: supports STDIO (inter-process communication) and HTTP SSE (event streaming)
  • Session layer: McpSession manages communication state and protocol version negotiation
  • Service layer: standardized services for tool invocation, resource management, and prompt-template injection

II. Core Components
1. Server side (McpServer)

  • Quick startup through the @EnableMcpServer annotation, with synchronous and asynchronous operation modes
  • Key capabilities:
// Example: a weather forecast server
@McpTool(name = "weather", description = "Query the weather for a city")
public String getWeather(@Param("city") String city) {
    return weatherService.fetch(city);
}

Tool auto-registration, resource URI mapping, and structured logging are supported.

2. Client side (McpClient)

  • Built-in protocol negotiation automatically handles compatibility and capability discovery
  • Multiple transport protocols:
// Configure an SSE client
McpClient client = new SseMcpClient("http://localhost:8080/mcp");

Dynamic context injection and batch operations are supported.

spring-ai-spring-cloud-bindings

spring-ai-spring-cloud-bindings is the Spring AI module for automated cloud-service integration; its standardized binding mechanism simplifies connecting AI components (such as vector stores and model services) to cloud platforms.
I. Core Features
1. Automatic cloud credential injection

  • Automatically obtains API keys, access tokens, and other credentials from cloud platforms (AWS, Azure, Alibaba Cloud, and others) and injects them dynamically through environment variables or application.yml
  • Supports per-environment isolation, e.g. separate cloud accounts for development and production

2. Service endpoint discovery

  • Automatically resolves cloud-service API endpoints (such as OpenAI regional endpoints), avoiding hard-coding
  • Integrates with Spring Cloud service discovery components (such as Nacos) for dynamic routing

3. Resource binding

  • Maps cloud storage (such as S3 buckets) and vector databases (such as Azure AI Search) onto Spring beans that can be injected directly into business code

II. Technical Implementation
1. Binding protocol

  • Based on the Spring Cloud Bindings specification, extended with AI-specific resource types (such as VectorStore and EmbeddingModel)
  • The module is activated with the @EnableAiCloudBindings annotation

2. Configuration example

# application.yml
spring:
  cloud:
    bindings:
      openai-api:
        type: ai-model
        provider: azure  # picks up the AZURE_OPENAI_KEY environment variable
      pinecone-store:
        type: vector-store
        provider: aws    # binds the Pinecone service on AWS Bedrock

Defaults can be overridden through the spring.cloud.bindings.* property prefix.

spring-ai-model-chat-memory

spring-ai-model-chat-memory is the core module of Spring AI implementing conversation memory (chat memory). It addresses the statelessness of large language models (LLMs) and supports context management across multi-turn conversations.
I. Core Features
1. Context maintenance

  • The ChatMemory interface automatically stores and retrieves history messages (UserMessage, SystemMessage, AssistantMessage), compensating for the LLM's lack of per-request memory
  • The default implementation, MessageWindowChatMemory, uses a sliding window that keeps the most recent 20 messages (configurable)

2. Storage extensions

  • Multiple persistence options are supported:
     - In-memory: the default InMemoryChatMemoryRepository (suitable for development)
     - JDBC: spring-ai-starter-model-chat-memory-repository-jdbc persists conversations to relational databases such as MySQL
     - NoSQL: compatible with Cassandra, Neo4j, and others

3. Advanced features

  • Vector-search augmentation: VectorStoreChatMemoryAdvisor retrieves history messages by semantic similarity
  • Dynamic cleanup policies: old messages can be evicted by message count, time window, or token limit

II. Technical Implementation
1. Core interface

public interface ChatMemory {
    void add(String conversationId, List<Message> messages);  // append messages
    List<Message> get(String conversationId);                 // fetch the conversation context
    void clear(String conversationId);                        // wipe the memory
}

It can be injected directly with @Autowired.

2. Configuration example

# application.yml
spring:
  ai:
    chat:
      memory:
        type: message-window  # sliding-window strategy
        size: 10              # window size
        storage: jdbc         # persistence backend

Memory policies can be adjusted flexibly through property files.

Source Code Walkthrough

1. BOM

<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>1.0.0-SNAPSHOT</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
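
With the BOM imported, individual starters can omit the version; for example (artifact name per the 1.0 naming scheme):

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-ollama</artifactId>
</dependency>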

2. Spring AI Configuration

Spring AI auto-configuration exposes, among others, the following configuration property namespaces:

spring.ai.anthropic
spring.ai.azure.openai
spring.ai.bedrock.aws
spring.ai.chat
spring.ai.huggingface.chat
spring.ai.image.observations
spring.ai.minimax
spring.ai.mistralai
spring.ai.moonshot
spring.ai.oci.genai
spring.ai.ollama
spring.ai.openai
spring.ai.postgresml
spring.ai.qianfan
spring.ai.retry
spring.ai.stabilityai
spring.ai.embedding.transformer
spring.ai.vectorstore
spring.ai.vertex
spring.ai.watsonx
spring.ai.zhipuai

3. ChatModel API


public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

	default String call(String message) {
		Prompt prompt = new Prompt(new UserMessage(message));
		Generation generation = call(prompt).getResult();
		return (generation != null) ? generation.getOutput().getText() : "";
	}

	default String call(Message... messages) {
		Prompt prompt = new Prompt(Arrays.asList(messages));
		Generation generation = call(prompt).getResult();
		return (generation != null) ? generation.getOutput().getText() : "";
	}

	@Override
	ChatResponse call(Prompt prompt);

	default ChatOptions getDefaultOptions() {
		return ChatOptions.builder().build();
	}

	default Flux<ChatResponse> stream(Prompt prompt) {
		throw new UnsupportedOperationException("streaming is not supported");
	}

}
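
A usage sketch against this interface (any ChatModel implementation will do; the prompt text is illustrative):

// Convenience form: a single user message in, text out
String answer = chatModel.call("Why is the sky blue?");

// Canonical form: a Prompt carries messages plus per-call options
ChatResponse response = chatModel.call(
    new Prompt("Why is the sky blue?", ChatOptions.builder().temperature(0.7).build()));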

Example: Ollama Chat, built on top of the Ollama API client.

public class OllamaChatModel implements ChatModel {

	private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class);

	private static final String DONE = "done";

	private static final String METADATA_PROMPT_EVAL_COUNT = "prompt-eval-count";

	private static final String METADATA_EVAL_COUNT = "eval-count";

	private static final String METADATA_CREATED_AT = "created-at";

	private static final String METADATA_TOTAL_DURATION = "total-duration";

	private static final String METADATA_LOAD_DURATION = "load-duration";

	private static final String METADATA_PROMPT_EVAL_DURATION = "prompt-eval-duration";

	private static final String METADATA_EVAL_DURATION = "eval-duration";

	private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

	private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

	private final OllamaApi chatApi;

	private final OllamaOptions defaultOptions;

	private final ObservationRegistry observationRegistry;

	private final OllamaModelManager modelManager;

	private final ToolCallingManager toolCallingManager;

	/**
	 * The tool execution eligibility predicate used to determine if a tool can be
	 * executed.
	 */
	private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

	private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

	public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
			ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
		this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
				new DefaultToolExecutionEligibilityPredicate());
	}

	public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
			ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
			ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
		Assert.notNull(ollamaApi, "ollamaApi must not be null");
		Assert.notNull(defaultOptions, "defaultOptions must not be null");
		Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
		Assert.notNull(observationRegistry, "observationRegistry must not be null");
		Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
		Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");
		this.chatApi = ollamaApi;
		this.defaultOptions = defaultOptions;
		this.toolCallingManager = toolCallingManager;
		this.observationRegistry = observationRegistry;
		this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
		this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
		initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
	}

	public static Builder builder() {
		return new Builder();
	}

	static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse previousChatResponse) {
		Assert.notNull(response, "OllamaApi.ChatResponse must not be null");

		DefaultUsage newUsage = getDefaultUsage(response);
		Integer promptTokens = newUsage.getPromptTokens();
		Integer generationTokens = newUsage.getCompletionTokens();
		int totalTokens = newUsage.getTotalTokens();

		Duration evalDuration = response.getEvalDuration();
		Duration promptEvalDuration = response.getPromptEvalDuration();
		Duration loadDuration = response.getLoadDuration();
		Duration totalDuration = response.getTotalDuration();

		if (previousChatResponse != null && previousChatResponse.getMetadata() != null) {
			if (previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION) != null) {
				evalDuration = evalDuration.plus(previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION));
			}
			if (previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION) != null) {
				promptEvalDuration = promptEvalDuration
					.plus(previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION));
			}
			if (previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION) != null) {
				loadDuration = loadDuration.plus(previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION));
			}
			if (previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION) != null) {
				totalDuration = totalDuration.plus(previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION));
			}
			if (previousChatResponse.getMetadata().getUsage() != null) {
				promptTokens += previousChatResponse.getMetadata().getUsage().getPromptTokens();
				generationTokens += previousChatResponse.getMetadata().getUsage().getCompletionTokens();
				totalTokens += previousChatResponse.getMetadata().getUsage().getTotalTokens();
			}
		}

		DefaultUsage aggregatedUsage = new DefaultUsage(promptTokens, generationTokens, totalTokens);

		return ChatResponseMetadata.builder()
			.usage(aggregatedUsage)
			.model(response.model())
			.keyValue(METADATA_CREATED_AT, response.createdAt())
			.keyValue(METADATA_EVAL_DURATION, evalDuration)
			.keyValue(METADATA_EVAL_COUNT, aggregatedUsage.getCompletionTokens().intValue())
			.keyValue(METADATA_LOAD_DURATION, loadDuration)
			.keyValue(METADATA_PROMPT_EVAL_DURATION, promptEvalDuration)
			.keyValue(METADATA_PROMPT_EVAL_COUNT, aggregatedUsage.getPromptTokens().intValue())
			.keyValue(METADATA_TOTAL_DURATION, totalDuration)
			.keyValue(DONE, response.done())
			.build();
	}

	private static DefaultUsage getDefaultUsage(OllamaApi.ChatResponse response) {
		return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0),
				Optional.ofNullable(response.evalCount()).orElse(0));
	}

	@Override
	public ChatResponse call(Prompt prompt) {
		// Before moving any further, build the final request Prompt,
		// merging runtime and default options.
		Prompt requestPrompt = buildRequestPrompt(prompt);
		return this.internalCall(requestPrompt, null);
	}

	private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

		OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);

		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
			.prompt(prompt)
			.provider(OllamaApiConstants.PROVIDER_NAME)
			.build();

		ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
			.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
					this.observationRegistry)
			.observe(() -> {

				OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);

				List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
						: ollamaResponse.message()
							.toolCalls()
							.stream()
							.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
									ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
							.toList();

				var assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls);

				ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
				if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
					generationMetadata = ChatGenerationMetadata.builder()
						.finishReason(ollamaResponse.doneReason())
						.build();
				}

				var generator = new Generation(assistantMessage, generationMetadata);
				ChatResponse chatResponse = new ChatResponse(List.of(generator),
						from(ollamaResponse, previousChatResponse));

				observationContext.setResponse(chatResponse);

				return chatResponse;

			});

		if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
			var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
			if (toolExecutionResult.returnDirect()) {
				// Return tool execution result directly to the client.
				return ChatResponse.builder()
					.from(response)
					.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
					.build();
			}
			else {
				// Send the tool execution result back to the model.
				return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
						response);
			}
		}

		return response;
	}

	@Override
	public Flux<ChatResponse> stream(Prompt prompt) {
		// Before moving any further, build the final request Prompt,
		// merging runtime and default options.
		Prompt requestPrompt = buildRequestPrompt(prompt);
		return this.internalStream(requestPrompt, null);
	}

	private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
		return Flux.deferContextual(contextView -> {
			OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true);

			final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
				.prompt(prompt)
				.provider(OllamaApiConstants.PROVIDER_NAME)
				.build();

			Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
					this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
					this.observationRegistry);

			observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

			Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(request);

			Flux<ChatResponse> chatResponse = ollamaResponse.map(chunk -> {
				String content = (chunk.message() != null) ? chunk.message().content() : "";

				List<AssistantMessage.ToolCall> toolCalls = List.of();

				// Added null checks to prevent NPE when accessing tool calls
				if (chunk.message() != null && chunk.message().toolCalls() != null) {
					toolCalls = chunk.message()
						.toolCalls()
						.stream()
						.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
								ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
						.toList();
				}

				var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);

				ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
				if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
					generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build();
				}

				var generator = new Generation(assistantMessage, generationMetadata);
				return new ChatResponse(List.of(generator), from(chunk, previousChatResponse));
			});

			// @formatter:off
			Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
					// FIXME: bounded elastic needs to be used since tool calling
					//  is currently only synchronous
					return Flux.defer(() -> {
						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
						if (toolExecutionResult.returnDirect()) {
							// Return tool execution result directly to the client.
							return Flux.just(ChatResponse.builder().from(response)
									.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
									.build());
						}
						else {
							// Send the tool execution result back to the model.
							return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
									response);
						}
					}).subscribeOn(Schedulers.boundedElastic());
				}
				else {
					return Flux.just(response);
				}
			})
			.doOnError(observation::error)
			.doFinally(s ->
				observation.stop()
			)
			.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
			// @formatter:on

			return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
		});
	}

	Prompt buildRequestPrompt(Prompt prompt) {
		// Process runtime options
		OllamaOptions runtimeOptions = null;
		if (prompt.getOptions() != null) {
			if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
				runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
						OllamaOptions.class);
			}
			else {
				runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
						OllamaOptions.class);
			}
		}

		// Define request options by merging runtime options and default options
		OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
				OllamaOptions.class);
		// Merge @JsonIgnore-annotated options explicitly since they are ignored by
		// Jackson, used by ModelOptionsUtils.
		if (runtimeOptions != null) {
			requestOptions.setInternalToolExecutionEnabled(
					ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
							this.defaultOptions.getInternalToolExecutionEnabled()));
			requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
					this.defaultOptions.getToolNames()));
			requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
					this.defaultOptions.getToolCallbacks()));
			requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
					this.defaultOptions.getToolContext()));
		}
		else {
			requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
			requestOptions.setToolNames(this.defaultOptions.getToolNames());
			requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
			requestOptions.setToolContext(this.defaultOptions.getToolContext());
		}

		// Validate request options
		if (!StringUtils.hasText(requestOptions.getModel())) {
			throw new IllegalArgumentException("model cannot be null or empty");
		}

		ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());

		return new Prompt(prompt.getInstructions(), requestOptions);
	}

	/**
	 * Package access for testing.
	 */
	OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {

		List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
			if (message instanceof UserMessage userMessage) {
				var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getText());
				if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
					messageBuilder.images(
							userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
				}
				return List.of(messageBuilder.build());
			}
			else if (message instanceof SystemMessage systemMessage) {
				return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(systemMessage.getText()).build());
			}
			else if (message instanceof AssistantMessage assistantMessage) {
				List<ToolCall> toolCalls = null;
				if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
					toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
						var function = new ToolCallFunction(toolCall.name(),
								JsonParser.fromJson(toolCall.arguments(), new TypeReference<>() {
								}));
						return new ToolCall(function);
					}).toList();
				}
				return List.of(OllamaApi.Message.builder(Role.ASSISTANT)
					.content(assistantMessage.getText())
					.toolCalls(toolCalls)
					.build());
			}
			else if (message instanceof ToolResponseMessage toolMessage) {
				return toolMessage.getResponses()
					.stream()
					.map(tr -> OllamaApi.Message.builder(Role.TOOL).content(tr.responseData()).build())
					.toList();
			}
			throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
		}).flatMap(List::stream).toList();

		OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions();

		OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
			.stream(stream)
			.messages(ollamaMessages)
			.options(requestOptions);

		if (requestOptions.getFormat() != null) {
			requestBuilder.format(requestOptions.getFormat());
		}

		if (requestOptions.getKeepAlive() != null) {
			requestBuilder.keepAlive(requestOptions.getKeepAlive());
		}

		List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
		if (!CollectionUtils.isEmpty(toolDefinitions)) {
			requestBuilder.tools(this.getTools(toolDefinitions));
		}

		return requestBuilder.build();
	}

	private String fromMediaData(Object mediaData) {
		if (mediaData instanceof byte[] bytes) {
			return Base64.getEncoder().encodeToString(bytes);
		}
		else if (mediaData instanceof String text) {
			return text;
		}
		else {
			throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName());
		}

	}

	private List<ChatRequest.Tool> getTools(List<ToolDefinition> toolDefinitions) {
		return toolDefinitions.stream().map(toolDefinition -> {
			var tool = new ChatRequest.Tool.Function(toolDefinition.name(), toolDefinition.description(),
					toolDefinition.inputSchema());
			return new ChatRequest.Tool(tool);
		}).toList();
	}

	@Override
	public ChatOptions getDefaultOptions() {
		return OllamaOptions.fromOptions(this.defaultOptions);
	}

	/**
	 * Pull the given model into Ollama based on the specified strategy.
	 */
	private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
		if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
			this.modelManager.pullModel(model, pullModelStrategy);
		}
	}

	/**
	 * Use the provided convention for reporting observation data
	 * @param observationConvention The provided convention
	 */
	public void setObservationConvention(ChatModelObservationConvention observationConvention) {
		Assert.notNull(observationConvention, "observationConvention cannot be null");
		this.observationConvention = observationConvention;
	}

	public static final class Builder {

		private OllamaApi ollamaApi;

		private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();

		private ToolCallingManager toolCallingManager;

		private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();

		private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

		private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

		private Builder() {
		}

		public Builder ollamaApi(OllamaApi ollamaApi) {
			this.ollamaApi = ollamaApi;
			return this;
		}

		public Builder defaultOptions(OllamaOptions defaultOptions) {
			this.defaultOptions = defaultOptions;
			return this;
		}

		public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
			this.toolCallingManager = toolCallingManager;
			return this;
		}

		public Builder toolExecutionEligibilityPredicate(
				ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
			this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
			return this;
		}

		public Builder observationRegistry(ObservationRegistry observationRegistry) {
			this.observationRegistry = observationRegistry;
			return this;
		}

		public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
			this.modelManagementOptions = modelManagementOptions;
			return this;
		}

		public OllamaChatModel build() {
			if (this.toolCallingManager != null) {
				return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
						this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
			}
			return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
					this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
		}

	}

}

4. Embeddings Model API

Embeddings are numerical representations of text, images, or videos that capture the relationships between inputs.

Embeddings work by converting text, images, and video into arrays of floating-point numbers called vectors. These vectors are designed to capture the meaning of the text, images, and videos. The length of the embedding array is called the vector's dimensionality.

public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {

	@Override
	EmbeddingResponse call(EmbeddingRequest request);

	/**
	 * Embeds the given text into a vector.
	 * @param text the text to embed.
	 * @return the embedded vector.
	 */
	default float[] embed(String text) {
		Assert.notNull(text, "Text must not be null");
		List<float[]> response = this.embed(List.of(text));
		return response.iterator().next();
	}

	/**
	 * Embeds the given document's content into a vector.
	 * @param document the document to embed.
	 * @return the embedded vector.
	 */
	float[] embed(Document document);

	/**
	 * Embeds a batch of texts into vectors.
	 * @param texts list of texts to embed.
	 * @return list of embedded vectors.
	 */
	default List<float[]> embed(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
			.getResults()
			.stream()
			.map(Embedding::getOutput)
			.toList();
	}

	/**
	 * Embeds a batch of {@link Document}s into vectors based on a
	 * {@link BatchingStrategy}.
	 * @param documents list of {@link Document}s.
	 * @param options {@link EmbeddingOptions}.
	 * @param batchingStrategy {@link BatchingStrategy}.
	 * @return a list of float[] that represents the vectors for the incoming
	 * {@link Document}s. The returned list is expected to be in the same order of the
	 * {@link Document} list.
	 */
	default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
		Assert.notNull(documents, "Documents must not be null");
		List<float[]> embeddings = new ArrayList<>(documents.size());
		List<List<Document>> batch = batchingStrategy.batch(documents);
		for (List<Document> subBatch : batch) {
			List<String> texts = subBatch.stream().map(Document::getText).toList();
			EmbeddingRequest request = new EmbeddingRequest(texts, options);
			EmbeddingResponse response = this.call(request);
			for (int i = 0; i < subBatch.size(); i++) {
				embeddings.add(response.getResults().get(i).getOutput());
			}
		}
		Assert.isTrue(embeddings.size() == documents.size(),
				"Embeddings must have the same number as that of the documents");
		return embeddings;
	}

	/**
	 * Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}.
	 * @param texts list of texts to embed.
	 * @return the embedding response.
	 */
	default EmbeddingResponse embedForResponse(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()));
	}

	/**
	 * Get the number of dimensions of the embedded vectors. Note that by default, this
	 * method will call the remote Embedding endpoint to get the dimensions of the
	 * embedded vectors. If the dimensions are known ahead of time, it is recommended to
	 * override this method.
	 * @return the number of dimensions of the embedded vectors.
	 */
	default int dimensions() {
		return embed("Test String").length;
	}

}
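
A usage sketch against this interface (the texts are illustrative):

// Single text in, vector out; the array length is the dimensionality
float[] vector = embeddingModel.embed("The quick brown fox");

// Batch form: one vector per input, in input order
List<float[]> vectors = embeddingModel.embed(List.of("first text", "second text"));

int dims = embeddingModel.dimensions();  // may call the remote endpoint once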

Example: OllamaEmbeddingModel
OllamaEmbeddingModel is the Spring AI component for interacting with Ollama embedding models; its main job is converting text into semantic vectors (embeddings) for all kinds of NLP tasks.

  • Core features
    1. Text vectorization
    • Converts input text into high-dimensional float vectors (e.g. 768 dimensions) used to measure semantic relatedness between texts
    • Callable through the native Ollama API or the OpenAI-compatible API, for example:
    curl http://localhost:11434/api/embed

    2. Model support
    • Compatible with embedding models such as bge-m3 (multilingual, multi-granularity) and nomic-embed-text (optimized for long text)
    • Models must be downloaded in advance with ollama pull <model-name>

public class OllamaEmbeddingModel extends AbstractEmbeddingModel {

	private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

	private final OllamaApi ollamaApi;

	private final OllamaOptions defaultOptions;

	private final ObservationRegistry observationRegistry;

	private final OllamaModelManager modelManager;

	private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

	public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
			ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
		Assert.notNull(ollamaApi, "ollamaApi must not be null");
		Assert.notNull(defaultOptions, "options must not be null");
		Assert.notNull(observationRegistry, "observationRegistry must not be null");
		Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");

		this.ollamaApi = ollamaApi;
		this.defaultOptions = defaultOptions;
		this.observationRegistry = observationRegistry;
		this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);

		initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
	}

	public static Builder builder() {
		return new Builder();
	}

	@Override
	public float[] embed(Document document) {
		return embed(document.getText());
	}

	@Override
	public EmbeddingResponse call(EmbeddingRequest request) {
		Assert.notEmpty(request.getInstructions(), "At least one text is required!");

		// Before moving any further, build the final request EmbeddingRequest,
		// merging runtime and default options.
		EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);

		OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(embeddingRequest);

		var observationContext = EmbeddingModelObservationContext.builder()
			.embeddingRequest(request)
			.provider(OllamaApiConstants.PROVIDER_NAME)
			.build();

		return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
			.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
					this.observationRegistry)
			.observe(() -> {
				EmbeddingsResponse response = this.ollamaApi.embed(ollamaEmbeddingRequest);

				AtomicInteger indexCounter = new AtomicInteger(0);

				List<Embedding> embeddings = response.embeddings()
					.stream()
					.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
					.toList();

				EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(),
						getDefaultUsage(response));

				EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, embeddingResponseMetadata);

				observationContext.setResponse(embeddingResponse);

				return embeddingResponse;
			});
	}

	private DefaultUsage getDefaultUsage(OllamaApi.EmbeddingsResponse response) {
		return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0), 0);
	}

	EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
		// Process runtime options
		OllamaOptions runtimeOptions = null;
		if (embeddingRequest.getOptions() != null) {
			runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
					OllamaOptions.class);
		}

		// Define request options by merging runtime options and default options
		OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
				OllamaOptions.class);

		// Validate request options
		if (!StringUtils.hasText(requestOptions.getModel())) {
			throw new IllegalArgumentException("model cannot be null or empty");
		}

		return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
	}

	/**
	 * Package access for testing.
	 */
	OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingRequest) {
		OllamaOptions requestOptions = (OllamaOptions) embeddingRequest.getOptions();
		return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(),
				DurationParser.parse(requestOptions.getKeepAlive()),
				OllamaOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate());
	}

	/**
	 * Pull the given model into Ollama based on the specified strategy.
	 */
	private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
		if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
			this.modelManager.pullModel(model, pullModelStrategy);
		}
	}

	/**
	 * Use the provided convention for reporting observation data
	 * @param observationConvention The provided convention
	 */
	public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
		Assert.notNull(observationConvention, "observationConvention cannot be null");
		this.observationConvention = observationConvention;
	}

	public static class DurationParser {

		private static final Pattern PATTERN = Pattern.compile("(-?\\d+)(ms|s|m|h)");

		public static Duration parse(String input) {

			if (!StringUtils.hasText(input)) {
				return null;
			}

			Matcher matcher = PATTERN.matcher(input);

			if (matcher.matches()) {
				long value = Long.parseLong(matcher.group(1));
				String unit = matcher.group(2);

				return switch (unit) {
					case "ms" -> Duration.ofMillis(value);
					case "s" -> Duration.ofSeconds(value);
					case "m" -> Duration.ofMinutes(value);
					case "h" -> Duration.ofHours(value);
					default -> throw new IllegalArgumentException("Unsupported time unit: " + unit);
				};
			}
			else {
				throw new IllegalArgumentException("Invalid duration format: " + input);
			}
		}

	}

	public static final class Builder {

		private OllamaApi ollamaApi;

		private OllamaOptions defaultOptions = OllamaOptions.builder()
			.model(OllamaModel.MXBAI_EMBED_LARGE.id())
			.build();

		private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

		private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

		private Builder() {
		}

		public Builder ollamaApi(OllamaApi ollamaApi) {
			this.ollamaApi = ollamaApi;
			return this;
		}

		public Builder defaultOptions(OllamaOptions defaultOptions) {
			this.defaultOptions = defaultOptions;
			return this;
		}

		public Builder observationRegistry(ObservationRegistry observationRegistry) {
			this.observationRegistry = observationRegistry;
			return this;
		}

		public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
			this.modelManagementOptions = modelManagementOptions;
			return this;
		}

		public OllamaEmbeddingModel build() {
			return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry,
					this.modelManagementOptions);
		}

	}

}

5. Image Model

@FunctionalInterface
public interface ImageModel extends Model<ImagePrompt, ImageResponse> {
	ImageResponse call(ImagePrompt request);
}
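
A usage sketch (the prompt text is illustrative; the accessor chain follows the ImageResponse/ImageGeneration types):

ImageResponse response = imageModel.call(
    new ImagePrompt("A watercolor painting of a lighthouse"));

String url = response.getResult().getOutput().getUrl();  // or getB64Json() for inline data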

6. Audio Model

7. Chat Memory

public interface ChatMemory {

	String DEFAULT_CONVERSATION_ID = "default";

	/**
	 * The key to retrieve the chat memory conversation id from the context.
	 */
	String CONVERSATION_ID = "chat_memory_conversation_id";

	/**
	 * Save the specified message in the chat memory for the specified conversation.
	 */
	default void add(String conversationId, Message message) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		Assert.notNull(message, "message cannot be null");
		this.add(conversationId, List.of(message));
	}

	/**
	 * Save the specified messages in the chat memory for the specified conversation.
	 */
	void add(String conversationId, List<Message> messages);

	/**
	 * Get the messages in the chat memory for the specified conversation.
	 */
	List<Message> get(String conversationId);

	/**
	 * Clear the chat memory for the specified conversation.
	 */
	void clear(String conversationId);

}

Example:

public final class MessageWindowChatMemory implements ChatMemory {

	private static final int DEFAULT_MAX_MESSAGES = 20;

	private final ChatMemoryRepository chatMemoryRepository;

	private final int maxMessages;

	private MessageWindowChatMemory(ChatMemoryRepository chatMemoryRepository, int maxMessages) {
		Assert.notNull(chatMemoryRepository, "chatMemoryRepository cannot be null");
		Assert.isTrue(maxMessages > 0, "maxMessages must be greater than 0");
		this.chatMemoryRepository = chatMemoryRepository;
		this.maxMessages = maxMessages;
	}

	@Override
	public void add(String conversationId, List<Message> messages) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		Assert.notNull(messages, "messages cannot be null");
		Assert.noNullElements(messages, "messages cannot contain null elements");

		List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
		List<Message> processedMessages = process(memoryMessages, messages);
		this.chatMemoryRepository.saveAll(conversationId, processedMessages);
	}

	@Override
	public List<Message> get(String conversationId) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		return this.chatMemoryRepository.findByConversationId(conversationId);
	}

	@Override
	public void clear(String conversationId) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		this.chatMemoryRepository.deleteByConversationId(conversationId);
	}

	private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
		List<Message> processedMessages = new ArrayList<>();

		Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
		boolean hasNewSystemMessage = newMessages.stream()
			.filter(SystemMessage.class::isInstance)
			.anyMatch(message -> !memoryMessagesSet.contains(message));

		memoryMessages.stream()
			.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
			.forEach(processedMessages::add);

		processedMessages.addAll(newMessages);

		if (processedMessages.size() <= this.maxMessages) {
			return processedMessages;
		}

		int messagesToRemove = processedMessages.size() - this.maxMessages;

		List<Message> trimmedMessages = new ArrayList<>();
		int removed = 0;
		for (Message message : processedMessages) {
			if (message instanceof SystemMessage || removed >= messagesToRemove) {
				trimmedMessages.add(message);
			}
			else {
				removed++;
			}
		}

		return trimmedMessages;
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private ChatMemoryRepository chatMemoryRepository;

		private int maxMessages = DEFAULT_MAX_MESSAGES;

		private Builder() {
		}

		public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) {
			this.chatMemoryRepository = chatMemoryRepository;
			return this;
		}

		public Builder maxMessages(int maxMessages) {
			this.maxMessages = maxMessages;
			return this;
		}

		public MessageWindowChatMemory build() {
			if (this.chatMemoryRepository == null) {
				this.chatMemoryRepository = new InMemoryChatMemoryRepository();
			}
			return new MessageWindowChatMemory(this.chatMemoryRepository, this.maxMessages);
		}

	}

}
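
A usage sketch combining the builder with the conversation-scoped ChatMemory API (the conversation id is illustrative):

// A window of 10 messages backed by the default in-memory repository
ChatMemory chatMemory = MessageWindowChatMemory.builder()
    .maxMessages(10)
    .build();

chatMemory.add("user-42", new UserMessage("My name is Ada."));
chatMemory.add("user-42", new AssistantMessage("Nice to meet you, Ada!"));

List<Message> history = chatMemory.get("user-42");  // feed back into the next Prompt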

8. Tool Calling

public interface ToolCallback {

	/**
	 * Definition used by the AI model to determine when and how to call the tool.
	 */
	ToolDefinition getToolDefinition();

	/**
	 * Metadata providing additional information on how to handle the tool.
	 */
	default ToolMetadata getToolMetadata() {
		return ToolMetadata.builder().build();
	}

	/**
	 * Execute tool with the given input and return the result to send back to the AI
	 * model.
	 */
	String call(String toolInput);

	/**
	 * Execute tool with the given input and context, and return the result to send back
	 * to the AI model.
	 */
	default String call(String toolInput, @Nullable ToolContext tooContext) {
		if (tooContext != null && !tooContext.getContext().isEmpty()) {
			throw new UnsupportedOperationException("Tool context is not supported!");
		}
		return call(toolInput);
	}

}
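
A hand-rolled implementation sketch, assuming the ToolDefinition builder from spring-ai-model (in practice most tools come from @Tool-annotated methods or the FunctionToolCallback factory instead; schema and return values are illustrative):

ToolCallback weatherTool = new ToolCallback() {

	@Override
	public ToolDefinition getToolDefinition() {
		return ToolDefinition.builder()
			.name("currentWeather")
			.description("Get the current weather for a given city")
			.inputSchema("""
					{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}
					""")
			.build();
	}

	@Override
	public String call(String toolInput) {
		// toolInput is the JSON argument payload produced by the model
		return "{\"temperature\":22,\"unit\":\"celsius\"}";
	}

};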
