[Spring Boot] A Universal Spring Boot + OpenAI API Integration Template for Rapid AI Integration

Published: 2025-08-05

1. Core Architecture Design

[Architecture overview] A Spring Boot application calls the OpenAI API for chat completion, image generation, text embeddings, and code generation, supported by a cache layer, rate limiting, audit logging, and centralized exception handling.

2. Universal Template Implementation

2.1 Base Configuration Class

@Configuration
public class OpenAIConfig {

    @Value("${openai.api.key}")
    private String apiKey;

    @Value("${openai.api.url}")
    private String apiUrl = "https://api.openai.com/v1";

    @Bean
    public WebClient openaiWebClient() {
        return WebClient.builder()
                .baseUrl(apiUrl)
                .defaultHeader("Authorization", "Bearer " + apiKey)
                .defaultHeader("Content-Type", "application/json")
                .build();
    }

    @Bean
    public ObjectMapper objectMapper() {
        return new ObjectMapper()
                .registerModule(new JavaTimeModule())
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
    }
}

2.2 Generic Request Template

@Slf4j
@Service
public class OpenAIService {

    private final WebClient webClient;
    private final ObjectMapper objectMapper;
    private final RateLimiter rateLimiter = RateLimiter.create(3); // Guava rate limiter: at most 3 calls per second

    public OpenAIService(WebClient webClient, ObjectMapper objectMapper) {
        this.webClient = webClient;
        this.objectMapper = objectMapper;
    }

    public <T> Mono<T> sendRequest(String endpoint, Object request, Class<T> responseType) {
        return Mono.fromCallable(() -> objectMapper.writeValueAsString(request))
                .flatMap(requestBody -> {
                    if (!rateLimiter.tryAcquire()) {
                        return Mono.error(new RateLimitExceededException("OpenAI API rate limit exceeded"));
                    }
                    
                    return webClient.post()
                            .uri(endpoint)
                            .bodyValue(requestBody)
                            .retrieve()
                            .onStatus(HttpStatus::is4xxClientError, response -> 
                                handleClientError(response, endpoint))
                            .onStatus(HttpStatus::is5xxServerError, response -> 
                                handleServerError(response, endpoint))
                            .bodyToMono(String.class)
                            .flatMap(responseBody -> parseResponse(responseBody, responseType));
                })
                .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
                .doOnError(e -> log.error("OpenAI API call failed", e))
                .doOnSuccess(response -> log.info("API call to {} succeeded", endpoint));
    }

    private <T> Mono<T> parseResponse(String responseBody, Class<T> responseType) {
        try {
            return Mono.just(objectMapper.readValue(responseBody, responseType));
        } catch (JsonProcessingException e) {
            return Mono.error(new OpenAIException("Failed to parse response", e));
        }
    }

    private Mono<? extends Throwable> handleClientError(ClientResponse response, String endpoint) {
        return response.bodyToMono(String.class)
                .flatMap(errorBody -> {
                    log.error("Client error for {}: {}", endpoint, errorBody);
                    return Mono.error(new OpenAIException("Client error: " + errorBody));
                });
    }

    private Mono<? extends Throwable> handleServerError(ClientResponse response, String endpoint) {
        return Mono.error(new OpenAIException("Server error for endpoint: " + endpoint));
    }
}
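The service above throws OpenAIException and RateLimitExceededException, which are not defined elsewhere in this article. A minimal sketch, assuming they are plain unchecked exceptions in your own package (each public class in its own file):

// Minimal sketch of the custom exceptions referenced by OpenAIService.
public class OpenAIException extends RuntimeException {

    public OpenAIException(String message) {
        super(message);
    }

    public OpenAIException(String message, Throwable cause) {
        super(message, cause);
    }
}

public class RateLimitExceededException extends RuntimeException {

    public RateLimitExceededException(String message) {
        super(message);
    }
}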

3. Wrapping Common API Features

3.1 Chat Completion API

@Service
public class ChatService {

    private final OpenAIService openAIService;

    public ChatService(OpenAIService openAIService) {
        this.openAIService = openAIService;
    }

    public Mono<String> chatCompletion(String prompt, String model) {
        ChatRequest request = new ChatRequest(
                model,
                List.of(new Message("user", prompt)),
                0.7, // temperature
                1000 // max tokens
        );

        return openAIService.sendRequest("/chat/completions", request, ChatResponse.class)
                .map(ChatResponse::getContent);
    }

    @Data
    @AllArgsConstructor
    private static class ChatRequest {
        private String model;
        private List<Message> messages;
        private double temperature;
        private int max_tokens;
    }

    @Data
    private static class Message {
        private final String role;
        private final String content;
    }

    @Data
    private static class ChatResponse {
        private List<Choice> choices;
        
        public String getContent() {
            return choices.get(0).getMessage().getContent();
        }
        
        @Data
        private static class Choice {
            private Message message;
        }
    }
}

3.2 Image Generation API

@Service
public class ImageService {

    private final OpenAIService openAIService;

    public ImageService(OpenAIService openAIService) {
        this.openAIService = openAIService;
    }

    public Mono<List<String>> generateImage(String prompt, int n, String size) {
        ImageRequest request = new ImageRequest(prompt, n, size);
        
        return openAIService.sendRequest("/images/generations", request, ImageResponse.class)
                .map(ImageResponse::getImageUrls);
    }

    @Data
    @AllArgsConstructor
    private static class ImageRequest {
        private String prompt;
        private int n;
        private String size;
    }

    @Data
    private static class ImageResponse {
        private List<ImageData> data;
        
        public List<String> getImageUrls() {
            return data.stream()
                    .map(ImageData::getUrl)
                    .collect(Collectors.toList());
        }
        
        @Data
        private static class ImageData {
            private String url;
        }
    }
}

3.3 Text Embedding API

@Service
public class EmbeddingService {

    private final OpenAIService openAIService;

    public EmbeddingService(OpenAIService openAIService) {
        this.openAIService = openAIService;
    }

    public Mono<List<Double>> getEmbedding(String text, String model) {
        EmbeddingRequest request = new EmbeddingRequest(model, text);
        
        return openAIService.sendRequest("/embeddings", request, EmbeddingResponse.class)
                .map(EmbeddingResponse::getEmbedding);
    }

    @Data
    @AllArgsConstructor
    private static class EmbeddingRequest {
        private String model;
        private String input;
    }

    @Data
    private static class EmbeddingResponse {
        private List<EmbeddingData> data;
        
        public List<Double> getEmbedding() {
            return data.get(0).getEmbedding();
        }
        
        @Data
        private static class EmbeddingData {
            private List<Double> embedding;
        }
    }
}

4. Advanced Feature Extensions

4.1 Multi-Turn Conversation with Context

@Service
public class ConversationService {

    private final OpenAIService openAIService;
    // Note: ChatRequest, ChatResponse, and Message are assumed to be extracted from ChatService
    // into shared (non-private) DTO classes so that this service can reuse them.
    private final Map<String, List<Message>> conversationHistory = new ConcurrentHashMap<>();

    public ConversationService(OpenAIService openAIService) {
        this.openAIService = openAIService;
    }

    public Mono<String> continueConversation(String sessionId, String userMessage) {
        List<Message> history = conversationHistory.computeIfAbsent(sessionId, k -> new ArrayList<>());
        history.add(new Message("user", userMessage));
        
        // Cap the history at the last 10 messages; copy the sublist so a real list is stored, not a view
        if (history.size() > 10) {
            history = new ArrayList<>(history.subList(history.size() - 10, history.size()));
            conversationHistory.put(sessionId, history);
        }
        
        ChatRequest request = new ChatRequest(
                "gpt-4",
                history,
                0.7,
                1000
        );
        
        return openAIService.sendRequest("/chat/completions", request, ChatResponse.class)
                .map(response -> {
                    String assistantResponse = response.getContent();
                    history.add(new Message("assistant", assistantResponse));
                    return assistantResponse;
                });
    }
}

4.2 Function Calling

@Service
public class FunctionCallService {

    private final OpenAIService openAIService;
    private final ObjectMapper objectMapper;
    // Note: the Message used in callFunction(...) is the shared chat Message DTO (role, content),
    // not the nested response Message defined below.

    public FunctionCallService(OpenAIService openAIService, ObjectMapper objectMapper) {
        this.openAIService = openAIService;
        this.objectMapper = objectMapper;
    }

    public Mono<FunctionCallResult> callFunction(String prompt, List<FunctionSpec> functions) {
        FunctionCallRequest request = new FunctionCallRequest(
                "gpt-4",
                List.of(new Message("user", prompt)),
                functions,
                0.7
        );
        
        return openAIService.sendRequest("/chat/completions", request, FunctionCallResponse.class)
                .map(response -> {
                    FunctionCall functionCall = response.getChoices().get(0).getMessage().getFunction_call();
                    return new FunctionCallResult(
                            functionCall.getName(),
                            parseArguments(functionCall.getArguments())
                    );
                });
    }

    private Map<String, Object> parseArguments(String argumentsJson) {
        try {
            return objectMapper.readValue(argumentsJson, new TypeReference<Map<String, Object>>() {});
        } catch (JsonProcessingException e) {
            throw new OpenAIException("Failed to parse function arguments", e);
        }
    }

    @Data
    @AllArgsConstructor
    private static class FunctionCallRequest {
        private String model;
        private List<Message> messages;
        private List<FunctionSpec> functions;
        private double temperature;
    }

    @Data
    private static class FunctionCallResponse {
        private List<Choice> choices;
        
        @Data
        private static class Choice {
            private Message message;
        }
        
        @Data
        private static class Message {
            private FunctionCall function_call;
        }
        
        @Data
        private static class FunctionCall {
            private String name;
            private String arguments;
        }
    }
    
    @Data
    @AllArgsConstructor
    public static class FunctionSpec {
        private String name;
        private String description;
        private Map<String, Object> parameters;
    }
    
    @Data
    @AllArgsConstructor
    public static class FunctionCallResult {
        private String functionName;
        private Map<String, Object> arguments;
    }
}
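For reference, a hypothetical call site showing how a FunctionSpec can be assembled with a JSON-Schema-style parameters map; the get_weather function and its fields are purely illustrative and not part of the template:

// Hypothetical usage: describe a get_weather function and let the model decide how to call it.
FunctionCallService.FunctionSpec weatherSpec = new FunctionCallService.FunctionSpec(
        "get_weather",
        "Get the current weather for a city",
        Map.of(
                "type", "object",
                "properties", Map.of(
                        "city", Map.of("type", "string", "description", "City name")
                ),
                "required", List.of("city")
        )
);

functionCallService.callFunction("What is the weather like in Beijing today?", List.of(weatherSpec))
        .subscribe(result -> System.out.println(result.getFunctionName() + " -> " + result.getArguments()));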

5. Security and Optimization

5.1 Sensitive Information Filtering

@Component
public class ContentFilter {

    private static final List<String> SENSITIVE_WORDS = List.of("password", "secret", "api_key");
    private static final Pattern CREDIT_CARD_PATTERN = Pattern.compile("\\b(?:\\d[ -]*?){13,16}\\b");

    public String filterSensitiveInfo(String text) {
        // Mask sensitive keywords
        for (String word : SENSITIVE_WORDS) {
            text = text.replaceAll("(?i)" + word, "***");
        }
        
        // Mask credit card numbers
        Matcher matcher = CREDIT_CARD_PATTERN.matcher(text);
        return matcher.replaceAll("[CARD_FILTERED]");
    }
}

5.2 Caching

@Configuration
@EnableCaching
public class CacheConfig {

    @Bean
    public CacheManager cacheManager() {
        return new ConcurrentMapCacheManager("openaiResponses");
    }
}
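The application.yml in section 8.1 sets spring.cache.type: caffeine, while the bean above uses an in-memory ConcurrentMapCacheManager. If you want the two to match, a sketch of a Caffeine-backed manager that would replace the cacheManager() bean above, assuming com.github.ben-manes.caffeine:caffeine and org.springframework.cache.caffeine.CaffeineCacheManager are on the classpath:

@Bean
public CacheManager cacheManager() {
    // Caffeine-backed cache with a TTL and size bound instead of an unbounded ConcurrentMap
    CaffeineCacheManager manager = new CaffeineCacheManager("openaiResponses");
    manager.setCaffeine(Caffeine.newBuilder()
            .expireAfterWrite(Duration.ofMinutes(30))
            .maximumSize(1_000));
    return manager;
}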

@Service
public class CachedOpenAIService {

    private final ChatService chatService;
    private final ContentFilter contentFilter;

    public CachedOpenAIService(ChatService chatService, ContentFilter contentFilter) {
        this.chatService = chatService;
        this.contentFilter = contentFilter;
    }

    // Note: @Cacheable stores the returned Mono itself, so .cache() is needed to replay the
    // resolved value to later subscribers instead of re-invoking the API.
    @Cacheable(value = "openaiResponses", key = "#prompt.hashCode()")
    public Mono<String> getCachedResponse(String prompt, String model) {
        String filteredPrompt = contentFilter.filterSensitiveInfo(prompt);
        return chatService.chatCompletion(filteredPrompt, model).cache();
    }
}

5.3 Rate Limiting Strategy

@Configuration
public class RateLimiterConfiguration {

    @Bean
    public RateLimiterRegistry rateLimiterRegistry() {
        // Resilience4j: 100 calls per minute; callers wait up to 5 seconds for a permit
        return RateLimiterRegistry.of(
            RateLimiterConfig.custom()
                .limitForPeriod(100)
                .limitRefreshPeriod(Duration.ofMinutes(1))
                .timeoutDuration(Duration.ofSeconds(5))
                .build()
        );
    }
}

@Aspect
@Component
public class RateLimiterAspect {

    private final RateLimiter rateLimiter;

    public RateLimiterAspect(RateLimiterRegistry registry) {
        this.rateLimiter = registry.rateLimiter("openaiApi");
    }

    @Around("@annotation(com.example.annotation.RateLimited)")
    public Object rateLimit(ProceedingJoinPoint joinPoint) throws Throwable {
        if (rateLimiter.acquirePermission()) {
            return joinPoint.proceed();
        } else {
            throw new RateLimitExceededException("API rate limit exceeded");
        }
    }
}
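The pointcut above matches methods annotated with @RateLimited, which is not shown in the article; a minimal sketch of the annotation, assumed to live in com.example.annotation as referenced by the pointcut:

package com.example.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// Marker annotation picked up by RateLimiterAspect.
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimited {
}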

6. Complete Controller Example

@RestController
@RequestMapping("/api/openai")
public class OpenAIController {

    private final ChatService chatService;
    private final ImageService imageService;
    private final ConversationService conversationService;

    public OpenAIController(ChatService chatService, ImageService imageService, ConversationService conversationService) {
        this.chatService = chatService;
        this.imageService = imageService;
        this.conversationService = conversationService;
    }

    @PostMapping("/chat")
    public Mono<ResponseEntity<String>> chat(@RequestBody ChatRequest request) {
        return chatService.chatCompletion(request.getPrompt(), request.getModel())
                .map(response -> ResponseEntity.ok(response));
    }

    @PostMapping("/image")
    public Mono<ResponseEntity<List<String>>> generateImage(@RequestBody ImageRequest request) {
        return imageService.generateImage(request.getPrompt(), request.getN(), request.getSize())
                .map(ResponseEntity::ok);
    }

    @PostMapping("/conversation/{sessionId}")
    public Mono<ResponseEntity<String>> continueConversation(
            @PathVariable String sessionId,
            @RequestBody UserMessage request) {
        return conversationService.continueConversation(sessionId, request.getMessage())
                .map(ResponseEntity::ok);
    }

    @Data
    private static class ChatRequest {
        private String prompt;
        private String model = "gpt-4";
    }

    @Data
    private static class ImageRequest {
        private String prompt;
        private int n = 1;
        private String size = "1024x1024";
    }

    @Data
    private static class UserMessage {
        private String message;
    }
}

7. Error Handling and Logging

7.1 Global Exception Handling

@RestControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler(OpenAIException.class)
    public ResponseEntity<ErrorResponse> handleOpenAIException(OpenAIException ex) {
        return ResponseEntity.status(HttpStatus.BAD_REQUEST)
                .body(new ErrorResponse("OPENAI_ERROR", ex.getMessage()));
    }

    @ExceptionHandler(RateLimitExceededException.class)
    public ResponseEntity<ErrorResponse> handleRateLimitException(RateLimitExceededException ex) {
        return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS)
                .body(new ErrorResponse("RATE_LIMIT_EXCEEDED", ex.getMessage()));
    }

    @ExceptionHandler(Exception.class)
    public ResponseEntity<ErrorResponse> handleGeneralException(Exception ex) {
        return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(new ErrorResponse("INTERNAL_ERROR", "An unexpected error occurred"));
    }

    @Data
    @AllArgsConstructor
    private static class ErrorResponse {
        private String code;
        private String message;
    }
}

7.2 Audit Logging

@Aspect
@Component
public class APIAuditAspect {

    @Autowired
    private AuditLogRepository auditLogRepository;

    // Note: the controller methods return Mono, so "result" below is the reactive publisher,
    // not the resolved response body; resolve it (or audit at the service layer) if the body is needed.
    @AfterReturning(pointcut = "execution(* com.example.controller.OpenAIController.*(..))", returning = "result")
    public void logSuccess(JoinPoint joinPoint, Object result) {
        Object[] args = joinPoint.getArgs();
        String methodName = joinPoint.getSignature().getName();
        
        AuditLog log = new AuditLog();
        log.setMethod(methodName);
        log.setRequest(serializeRequest(args));
        log.setResponse(serializeResponse(result));
        log.setStatus("SUCCESS");
        log.setTimestamp(LocalDateTime.now());
        
        auditLogRepository.save(log);
    }

    @AfterThrowing(pointcut = "execution(* com.example.controller.OpenAIController.*(..))", throwing = "ex")
    public void logError(JoinPoint joinPoint, Throwable ex) {
        Object[] args = joinPoint.getArgs();
        String methodName = joinPoint.getSignature().getName();
        
        AuditLog log = new AuditLog();
        log.setMethod(methodName);
        log.setRequest(serializeRequest(args));
        log.setResponse(ex.getMessage());
        log.setStatus("ERROR");
        log.setTimestamp(LocalDateTime.now());
        
        auditLogRepository.save(log);
    }

    private String serializeRequest(Object[] args) {
        try {
            return new ObjectMapper().writeValueAsString(args);
        } catch (JsonProcessingException e) {
            return "Serialization error";
        }
    }

    private String serializeResponse(Object response) {
        try {
            return new ObjectMapper().writeValueAsString(response);
        } catch (JsonProcessingException e) {
            return "Serialization error";
        }
    }
}
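The aspect persists an AuditLog through an AuditLogRepository, neither of which appears elsewhere in the article. A minimal sketch assuming Spring Data JPA (jakarta.persistence and org.springframework.data.jpa imports); the column lengths and id strategy are illustrative:

// Minimal JPA sketch of the audit entity and repository used by APIAuditAspect.
@Entity
@Data
public class AuditLog {

    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    private String method;

    @Column(length = 4000)
    private String request;

    @Column(length = 4000)
    private String response;

    private String status;
    private LocalDateTime timestamp;
}

public interface AuditLogRepository extends JpaRepository<AuditLog, Long> {
}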

8. Deployment Configuration

8.1 application.yml

openai:
  api:
    key: ${OPENAI_API_KEY}
    url: https://api.openai.com/v1
    timeout: 30s

spring:
  cache:
    type: caffeine
  redis:
    host: localhost
    port: 6379

logging:
  level:
    root: INFO
    com.example: DEBUG
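Note that openai.api.timeout is declared here but not consumed by OpenAIConfig in section 2.1. One way to honour it, sketched below, is to bind the value as a string and apply it as a response timeout on the underlying Reactor Netty client (HttpClient is reactor.netty.http.client.HttpClient, DurationStyle is org.springframework.boot.convert.DurationStyle); this wiring is an assumption, not part of the original template:

@Value("${openai.api.timeout:30s}")
private String timeout;

@Bean
public WebClient openaiWebClient() {
    // Parse values such as "30s" and apply them as the HTTP response timeout.
    HttpClient httpClient = HttpClient.create()
            .responseTimeout(DurationStyle.detectAndParse(timeout));
    return WebClient.builder()
            .baseUrl(apiUrl)
            .defaultHeader("Authorization", "Bearer " + apiKey)
            .defaultHeader("Content-Type", "application/json")
            .clientConnector(new ReactorClientHttpConnector(httpClient))
            .build();
}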

8.2 Dockerfile

FROM eclipse-temurin:17-jdk-alpine
VOLUME /tmp
ARG JAR_FILE=target/*.jar
COPY ${JAR_FILE} app.jar
ENTRYPOINT ["java","-Djava.security.egd=file:/dev/./urandom","-jar","/app.jar"]

9. Usage Examples

9.1 Chat Request

curl -X POST http://localhost:8080/api/openai/chat \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "解释量子计算的基本原理",
    "model": "gpt-4"
  }'

9.2 Image Generation

curl -X POST http://localhost:8080/api/openai/image \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "未来城市景观,赛博朋克风格",
    "n": 2,
    "size": "1024x1024"
  }'

9.3 Multi-Turn Conversation

curl -X POST http://localhost:8080/api/openai/conversation/session123 \
  -H "Content-Type: application/json" \
  -d '{
    "message": "上一个问题中提到的量子比特是什么?"
  }'

10. Best Practices

  1. Key management:
    • Store the API key in environment variables or a secrets management service
    • Avoid hard-coding sensitive information in the codebase
  2. Cost control:
    • Set API usage quotas
    • Monitor OpenAI API consumption
    • Use caching to avoid duplicate requests
  3. Performance:
    • Configure sensible timeouts
    • Use asynchronous, non-blocking calls
    • Batch requests where possible
  4. Security and compliance:
    • Implement content filtering
    • Comply with OpenAI's usage policies
    • Keep audit logs
  5. Error handling:
    • Implement retries
    • Degrade gracefully (see the sketch after this list)
    • Monitor the API error rate
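A sketch of the graceful-degradation point from item 5, wrapping ChatService with a timeout and canned fallbacks; the wrapper method and fallback messages are assumptions, not part of the template above:

// Graceful degradation: time-box the call and fall back to canned answers instead of failing.
public Mono<String> chatWithFallback(String prompt, String model) {
    return chatService.chatCompletion(prompt, model)
            .timeout(Duration.ofSeconds(30))
            .onErrorResume(RateLimitExceededException.class,
                    e -> Mono.just("The service is busy right now, please try again later."))
            .onErrorReturn("The AI service is temporarily unavailable.");
}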

With this universal template, you can quickly integrate the various OpenAI API capabilities while keeping the system stable, secure, and extensible.

