Spring AI 指定模型的三种方式(Ollama)
在实际开发中,通过 Spring AI 调用 Ollama 时,常用的模型指定方式有以下三种:
1. 从 yml 配置读取默认模型
注意: 这是最基础、最推荐的方式,必须先配置好才能用自动注入的 OllamaChatModel。
# Default model for the auto-configured OllamaChatModel.
# Nesting matters: this resolves to spring.ai.ollama.chat.options.model.
spring:
  ai:
    ollama:
      base-url: http://localhost:11434
      chat:
        options:
          model: deepseek-r1:7b
// Auto-injected OllamaChatModel; its default model comes from
// spring.ai.ollama.chat.options.model in the yml configuration.
// NOTE(review): if several OllamaChatModel beans are defined (see approach 3),
// injection by type alone becomes ambiguous — add a @Qualifier in that case.
@Autowired
private OllamaChatModel chatModel;
// Calling chatModel.call(...) directly uses the default model
2. Prompt 临时指定模型
通过 Prompt 构造时传入 OllamaOptions,可临时切换模型:
import org.springframework.ai.ollama.api.OllamaOptions;
// Per-request options: the model chosen here applies only to this Prompt.
OllamaOptions requestOptions = OllamaOptions.builder()
        .model("qwen2.5-vl") // model used for this request only
        .build();
Prompt prompt = new Prompt(messageList, requestOptions);
return chatModel.stream(prompt);
3. 创建多个 OllamaChatModel 动态切换
可在配置类中为不同模型创建多个 Bean,或用工厂模式动态切换:
/** Base URL of the local Ollama server shared by the model beans below. */
private static final String OLLAMA_BASE_URL = "http://localhost:11434";

/**
 * Chat model bean whose default model is {@code qwen2.5-vl:3b}.
 *
 * <p>With more than one OllamaChatModel bean in the context, select this one explicitly,
 * e.g. {@code @Qualifier("ollamaQwenModel")}.
 * NOTE(review): Spring AI's auto-configured OllamaChatModel presumably backs off once these
 * beans exist — confirm against the auto-configuration in use.
 */
@Bean
public OllamaChatModel ollamaQwenModel() {
    return buildOllamaChatModel("qwen2.5-vl:3b");
}

/**
 * Chat model bean whose default model is {@code llama2:7b}.
 * Select it explicitly, e.g. {@code @Qualifier("ollamaLlamaModel")}.
 */
@Bean
public OllamaChatModel ollamaLlamaModel() {
    return buildOllamaChatModel("llama2:7b");
}

/**
 * Builds an OllamaChatModel against {@link #OLLAMA_BASE_URL} whose default options select
 * {@code model}.
 *
 * @param model Ollama model tag, e.g. {@code "llama2:7b"}
 * @return a new chat model instance
 */
private static OllamaChatModel buildOllamaChatModel(String model) {
    OllamaApi ollamaApi = OllamaApi.builder().baseUrl(OLLAMA_BASE_URL).build();
    return OllamaChatModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(OllamaOptions.builder().model(model).build())
            .build();
}
或通过自定义工厂类,根据参数动态返回不同模型实例:
/**
 * Sketch of a factory that returns a different {@link OllamaChatModel} depending on the
 * requested model name.
 */
public class DynamicModelFactory {
    /**
     * Returns the chat model registered under {@code modelName}.
     *
     * @param modelName name of the model to look up
     * @return the matching chat model
     * @throws UnsupportedOperationException until the lookup is implemented
     */
    public OllamaChatModel getModelByName(String modelName) {
        // TODO: return a different OllamaChatModel instance depending on modelName.
        // Throwing keeps this sketch compilable (a non-void method needs a return or throw).
        throw new UnsupportedOperationException("getModelByName is not implemented yet: " + modelName);
    }
}
接口调用时根据参数动态切换:
OllamaChatModel model = dynamicModelFactory.getModelByName(modelName);
return model.stream(prompt);
建议将所有模型统一维护在数据库中:大部分模型(尤其是同一供应商的模型)调用方式都相同,差异只在名称、地址、密钥和参数等配置上。
-- Model registry. The application looks models up by `name` (findByName) and caches them
-- keyed by name, so `name` must be unique or later rows silently overwrite earlier ones.
CREATE TABLE `ai_model` (
`id` bigint NOT NULL AUTO_INCREMENT,
`vendor` varchar(64) NOT NULL COMMENT '供应商',
`icon` varchar(255) DEFAULT NULL COMMENT '图标URL',
`name` varchar(128) NOT NULL COMMENT '模型名称',
`api_key` varchar(255) DEFAULT NULL COMMENT '密钥',
`api_url` varchar(255) NOT NULL COMMENT '模型API地址',
`tags` varchar(255) DEFAULT NULL COMMENT '标签(推理、对话、图片、语音等,逗号分隔)',
`type` varchar(32) NOT NULL COMMENT '类型(对话、图片、音频、视频、量化)',
`status` tinyint NOT NULL DEFAULT 1 COMMENT '模型可用状态 1:可用 0:不可用',
`description` text COMMENT '模型描述',
`params` json DEFAULT NULL COMMENT '模型参数(如温度等)',
`create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `uk_ai_model_name` (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
DynamicModelFactory.java
/**
 * Dynamic model factory: loads every model from {@link AiModelRepository} once at startup and
 * caches one {@link MyModel} per model name.
 *
 * <p>Only the {@code ollama} vendor currently gets an {@link OllamaChatModel}; models of other
 * vendors are cached with a {@code null} chat model until their support is added.
 *
 * <p>The cache is a concurrent map because {@link #refreshModel(String)} may replace entries
 * while request threads are reading through {@link #getModelByName(String)}.
 */
@Component
public class DynamicModelFactory {

    private static final System.Logger LOG = System.getLogger(DynamicModelFactory.class.getName());

    /** Shared mapper for parsing the JSON {@code params} column; readValue is thread-safe. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Autowired
    private AiModelRepository aiModelRepository;

    /** Cached models keyed by model name. */
    private final Map<String, MyModel> modelHashMap = new java.util.concurrent.ConcurrentHashMap<>();

    /** Loads all models from the repository into the cache after dependency injection. */
    @PostConstruct
    public void init() {
        List<AiModel> models = aiModelRepository.findAll();
        for (AiModel m : models) {
            modelHashMap.put(m.getName(), buildModel(m));
        }
    }

    /**
     * Returns the cached model with the given name.
     *
     * @param name model name
     * @return the cached model, or {@code null} if none is cached under that name
     */
    public MyModel getModelByName(String name) {
        // ConcurrentHashMap rejects null keys; keep the previous "null in, null out" behaviour.
        return name == null ? null : modelHashMap.get(name);
    }

    /**
     * Reloads one model from the repository and replaces its cache entry. If the repository no
     * longer contains a model with that name, the stale cache entry is evicted.
     *
     * @param modelName name of the model to reload
     */
    public void refreshModel(String modelName) {
        AiModel m = aiModelRepository.findByName(modelName);
        if (m == null) {
            if (modelName != null) {
                modelHashMap.remove(modelName);
            }
            return;
        }
        modelHashMap.put(m.getName(), buildModel(m));
    }

    /** Copies an entity into a new {@link MyModel} and builds its vendor-specific chat model. */
    private MyModel buildModel(AiModel m) {
        MyModel myModel = new MyModel();
        myModel.vendor = m.getVendor();
        myModel.name = m.getName();
        myModel.apiKey = m.getApiKey();
        myModel.apiUrl = m.getApiUrl();
        myModel.type = m.getType();
        myModel.status = m.getStatus();
        myModel.params = m.getParams();
        // Only Ollama is implemented here; other vendors can be added later.
        if ("ollama".equalsIgnoreCase(m.getVendor())) {
            myModel.chatModel = buildOllamaChatModel(m);
        }
        // TODO: implementations for other vendors
        return myModel;
    }

    /** Builds an OllamaChatModel for the entity's API URL and model name, applying its params. */
    private static OllamaChatModel buildOllamaChatModel(AiModel m) {
        OllamaApi ollamaApi = OllamaApi.builder().baseUrl(m.getApiUrl()).build();
        OllamaOptions.Builder optionsBuilder = OllamaOptions.builder().model(m.getName());
        applyParams(optionsBuilder, m);
        return OllamaChatModel.builder()
                .ollamaApi(ollamaApi)
                .defaultOptions(optionsBuilder.build())
                .build();
    }

    /**
     * Copies supported entries of the entity's JSON {@code params} (currently only
     * {@code temperature}) into {@code optionsBuilder}.
     * Parsing is best effort: invalid params are logged and skipped so the model still loads.
     */
    private static void applyParams(OllamaOptions.Builder optionsBuilder, AiModel m) {
        String params = m.getParams();
        if (params == null || params.isEmpty()) {
            return;
        }
        try {
            @SuppressWarnings("unchecked") // Jackson maps a JSON object to Map<String, Object>
            Map<String, Object> paramMap = OBJECT_MAPPER.readValue(params, Map.class);
            Object temperature = paramMap.get("temperature");
            if (temperature != null) {
                optionsBuilder.temperature(Double.parseDouble(temperature.toString()));
            }
            // More parameters can be mapped here.
        } catch (Exception e) {
            LOG.log(System.Logger.Level.WARNING,
                    "Ignoring invalid params for model " + m.getName(), e);
        }
    }

    /** Cached view of one {@code ai_model} row plus its vendor-specific chat model. */
    public static class MyModel {
        @Schema(description = "供应商")
        private String vendor;
        @Schema(description = "模型名称")
        private String name;
        @Schema(description = "密钥")
        private String apiKey;
        @Schema(description = "模型API地址")
        private String apiUrl;
        @Schema(description = "类型(对话、图片、音频、视频、量化)")
        private String type;
        @Schema(description = "模型可用状态 1:可用 0:不可用")
        private Integer status;
        @Schema(description = "模型参数(如温度等,json格式)")
        private String params;
        @Schema(description = "ollama对应的会话对象")
        private OllamaChatModel chatModel;
        // TODO: chat models for other vendors

        /** @return the Ollama chat model, or {@code null} for non-Ollama vendors */
        public OllamaChatModel getChatModel() { return chatModel; }
        public String getVendor() { return vendor; }
        public String getName() { return name; }
        public String getApiKey() { return apiKey; }
        public String getApiUrl() { return apiUrl; }
        public String getType() { return type; }
        public Integer getStatus() { return status; }
        public String getParams() { return params; }
    }
}