【Spring AI】2. 指定模型的三种方式(Ollama)

发布于:2025-07-03 ⋅ 阅读:(19) ⋅ 点赞:(0)

Spring AI 指定模型的三种方式(Ollama)

在实际开发中,Ollama 支持三种常用的模型指定方式:

1. 从 yml 配置读取默认模型

注意: 这是最基础、最推荐的方式,必须先配置好才能用自动注入的 OllamaChatModel。

# Spring AI Ollama settings; the model configured here is the default one.
spring:
  ai:
    ollama:
      base-url: http://localhost:11434  # address of the Ollama server
      chat:
        options:
          model: deepseek-r1:7b  # default model used when no other model is specified
// Auto-injected chat model; per the article, the YAML defaults must be configured before it can be used.
@Autowired
private OllamaChatModel chatModel;
// Calling chatModel.call(...) directly uses the default model.

2. Prompt 临时指定模型

通过 Prompt 构造时传入 OllamaOptions,可临时切换模型:

import org.springframework.ai.ollama.api.OllamaOptions;

    // Build a prompt that carries its own OllamaOptions, so the model is switched for this prompt only.
    Prompt prompt = new Prompt(
        messageList,
        OllamaOptions.builder()
            .model("qwen2.5-vl") // temporarily selected model
            .build()
    );
    return chatModel.stream(prompt);

3. 创建多个 OllamaChatModel 动态切换

可在配置类中为不同模型创建多个 Bean,或用工厂模式动态切换:

/**
 * Registers a chat model bean whose default options select the
 * {@code qwen2.5-vl:3b} model on the Ollama server at localhost:11434.
 *
 * @return a chat model preconfigured for {@code qwen2.5-vl:3b}
 */
@Bean
public OllamaChatModel ollamaQwenModel() {
    // Default options carry only the model name.
    OllamaOptions defaultOptions = OllamaOptions.builder()
            .model("qwen2.5-vl:3b")
            .build();
    OllamaApi api = OllamaApi.builder()
            .baseUrl("http://localhost:11434")
            .build();
    return OllamaChatModel.builder()
            .ollamaApi(api)
            .defaultOptions(defaultOptions)
            .build();
}

/**
 * Registers a chat model bean whose default options select the
 * {@code llama2:7b} model on the Ollama server at localhost:11434.
 *
 * @return a chat model preconfigured for {@code llama2:7b}
 */
@Bean
public OllamaChatModel ollamaLlamaModel() {
    // Default options carry only the model name.
    OllamaOptions llamaDefaults = OllamaOptions.builder()
            .model("llama2:7b")
            .build();
    OllamaApi localApi = OllamaApi.builder()
            .baseUrl("http://localhost:11434")
            .build();
    return OllamaChatModel.builder()
            .ollamaApi(localApi)
            .defaultOptions(llamaDefaults)
            .build();
}

或通过自定义工厂类,根据参数动态返回不同模型实例:

/** Factory sketch: returns a chat model instance chosen by model name. */
public class DynamicModelFactory {
    public OllamaChatModel getModelByName(String modelName) {
        // ...return a different OllamaChatModel instance depending on modelName...
    }
}

接口调用时根据参数动态切换:

// Look up the chat model for the requested name, then stream the prompt through it.
OllamaChatModel model = dynamicModelFactory.getModelByName(modelName);
return model.stream(prompt);

建议将所有模型统一维护到数据库中,因为大部分模型(尤其是同一供应商的模型)的调用方式都相同。

-- Registry of AI models: one row per model with its vendor, endpoint, credentials and
-- optional JSON parameters.
-- NOTE(review): `name` has no unique index; if code looks models up by name, duplicate
-- names would be ambiguous -- confirm whether a UNIQUE KEY is wanted.
-- NOTE(review): utf8mb4_german2_ci is a German collation; confirm it is intended here.
CREATE TABLE `ai_model` (
  `id` bigint NOT NULL AUTO_INCREMENT,
  `vendor` varchar(64) NOT NULL COMMENT '供应商',
  `icon` varchar(255) DEFAULT NULL COMMENT '图标URL',
  `name` varchar(128) NOT NULL COMMENT '模型名称',
  `api_key` varchar(255) DEFAULT NULL COMMENT '密钥',
  `api_url` varchar(255) NOT NULL COMMENT '模型API地址',
  `tags` varchar(255) DEFAULT NULL COMMENT '标签(推理、对话、图片、语音等,逗号分隔)',
  `type` varchar(32) NOT NULL COMMENT '类型(对话、图片、音频、视频、量化)',
  `status` tinyint NOT NULL DEFAULT 1 COMMENT '模型可用状态 1:可用 0:不可用',
  `description` text COMMENT '模型描述',
  `params` json DEFAULT NULL COMMENT '模型参数(如温度等)',
  `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
  `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_german2_ci;

DynamicModelFactory.java

/**
 * Dynamic model factory that caches one {@link MyModel} per stored {@code AiModel}.
 *
 * <p>All models returned by the repository are loaded once at startup by
 * {@link #init()}; a single entry can be reloaded later with
 * {@link #refreshModel(String)}. Only the {@code ollama} vendor currently gets a
 * chat client; entries of any other vendor are cached with a {@code null} chat model.
 *
 * <p>The cache is a concurrent map so that lookups can safely overlap with a
 * refresh running on another thread.
 */
@Component
public class DynamicModelFactory {
    private static final System.Logger LOG = System.getLogger(DynamicModelFactory.class.getName());

    /** Shared JSON parser for the params column; reused instead of allocating one per call. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Autowired
    private AiModelRepository aiModelRepository;

    /** Cache keyed by model name. ConcurrentHashMap rejects null keys, so lookups guard against null. */
    private final Map<String, MyModel> modelHashMap = new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Loads every model from the repository into the cache.
     * Runs once after dependency injection.
     */
    @PostConstruct
    public void init() {
        List<AiModel> models = aiModelRepository.findAll();
        for (AiModel m : models) {
            modelHashMap.put(m.getName(), buildModel(m));
        }
    }

    /**
     * Returns the cached model for the given name.
     *
     * @param name model name used as the cache key
     * @return the cached entry, or {@code null} when the name is {@code null} or unknown
     */
    public MyModel getModelByName(String name) {
        if (name == null) {
            return null;
        }
        return modelHashMap.get(name);
    }

    /**
     * Reloads a single model from the repository and replaces its cache entry.
     * When the repository no longer has a model with that name, the cached entry
     * is evicted so that a deleted model stops being served.
     *
     * @param modelName name of the model to reload
     */
    public void refreshModel(String modelName) {
        AiModel m = aiModelRepository.findByName(modelName);
        if (m == null) {
            if (modelName != null) {
                modelHashMap.remove(modelName);
            }
            return;
        }
        modelHashMap.put(m.getName(), buildModel(m));
    }

    /** Copies the stored fields into a cache entry and attaches a chat client for supported vendors. */
    private MyModel buildModel(AiModel m) {
        MyModel myModel = new MyModel();
        myModel.vendor = m.getVendor();
        myModel.name = m.getName();
        myModel.apiKey = m.getApiKey();
        myModel.apiUrl = m.getApiUrl();
        myModel.type = m.getType();
        myModel.status = m.getStatus();
        myModel.params = m.getParams();
        // Only Ollama is implemented here; other vendors can be added later.
        if ("ollama".equalsIgnoreCase(m.getVendor())) {
            myModel.chatModel = buildOllamaChatModel(m);
        }
        // TODO: implement other vendors
        return myModel;
    }

    /** Builds an Ollama chat client that targets the model's API URL and uses its name and params as defaults. */
    private OllamaChatModel buildOllamaChatModel(AiModel m) {
        OllamaApi ollamaApi = OllamaApi.builder().baseUrl(m.getApiUrl()).build();
        OllamaOptions.Builder optionsBuilder = OllamaOptions.builder().model(m.getName());
        applyParams(optionsBuilder, m);
        return OllamaChatModel.builder()
                .ollamaApi(ollamaApi)
                .defaultOptions(optionsBuilder.build())
                .build();
    }

    /**
     * Applies the optional JSON params of the model to the options builder.
     * Parsing is best-effort: on failure the builder is left unchanged and a
     * warning is logged instead of the error being silently dropped.
     */
    private void applyParams(OllamaOptions.Builder optionsBuilder, AiModel m) {
        String params = m.getParams();
        if (params == null || params.isEmpty()) {
            return;
        }
        try {
            Map<?, ?> paramMap = OBJECT_MAPPER.readValue(params, Map.class);
            if (paramMap == null) {
                // The JSON document was the literal null: nothing to apply.
                return;
            }
            Object temperature = paramMap.get("temperature");
            if (temperature != null) {
                optionsBuilder.temperature(Double.parseDouble(temperature.toString()));
            }
            // More parameters can be mapped here.
        } catch (Exception e) {
            // Deliberately broad: a bad params value must not prevent the model from being cached.
            LOG.log(System.Logger.Level.WARNING,
                    "Ignoring unparsable params for model " + m.getName(), e);
        }
    }

    /** Cached view of one stored model together with its chat client, if any. */
    public static class MyModel {
        @Schema(description = "供应商")
        private String vendor;
        @Schema(description = "模型名称")
        private String name;
        @Schema(description = "密钥")
        private String apiKey;
        @Schema(description = "模型API地址")
        private String apiUrl;
        @Schema(description = "类型(对话、图片、音频、视频、量化)")
        private String type;
        @Schema(description = "模型可用状态 1:可用 0:不可用")
        private Integer status;
        @Schema(description = "模型参数(如温度等,json格式)")
        private String params;
        @Schema(description = "ollama对应的会话对象")
        private OllamaChatModel chatModel;
        // TODO: chat models of other vendors

        /** @return the Ollama chat client, or {@code null} when the vendor is not Ollama */
        public OllamaChatModel getChatModel() { return chatModel; }
        public String getVendor() { return vendor; }
        public String getName() { return name; }
        public String getApiKey() { return apiKey; }
        public String getApiUrl() { return apiUrl; }
        public String getType() { return type; }
        public Integer getStatus() { return status; }
        public String getParams() { return params; }
    }
}