Spring Boot 中 OkHttp 与压缩文件的使用

发布于:2024-10-17 ⋅ 阅读:(8) ⋅ 点赞:(0)

OkHttp 与压缩文件实战

一、发起 HTTP 请求并处理响应

import okhttp3.*;
import org.junit.jupiter.api.*;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Collectors;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.Map;

/**
 * Calls external HTTP APIs asynchronously through one shared OkHttp client.
 *
 * <p>Every call runs on a fixed pool of 10 daemon threads. Daemon threads do not keep the JVM
 * alive, so a caller that needs the results must block on the returned futures (as {@link #main}
 * does).
 */
public class ApiServiceCaller {
    // Fixed pool of 10 daemon threads that executes every outbound call.
    private static final ExecutorService executor = Executors.newFixedThreadPool(10, runnable -> {
        Thread thread = new Thread(runnable);
        thread.setName("ApiServiceCaller-Thread");
        thread.setDaemon(true);
        return thread;
    });
    private static final Logger logger = Logger.getLogger(ApiServiceCaller.class.getName());
    private static final OkHttpClient client = new OkHttpClient.Builder()
            .connectTimeout(5, TimeUnit.SECONDS)
            .readTimeout(5, TimeUnit.SECONDS)
            .connectionPool(new ConnectionPool(10, 5, TimeUnit.MINUTES))
            .retryOnConnectionFailure(true)
            .build();
    // Shared JSON serializer for POST bodies. It is created once instead of per request;
    // ObjectMapper is thread-safe once configured.
    private static final com.fasterxml.jackson.databind.ObjectMapper OBJECT_MAPPER =
            new com.fasterxml.jackson.databind.ObjectMapper();

    /**
     * Calls an external API asynchronously on the shared executor.
     *
     * @param url    target URL
     * @param params query parameters (GET) or JSON body fields (POST); may be null for GET
     * @param method "GET" or "POST" (case-insensitive)
     * @return a future holding the response body. It completes exceptionally with a
     *         RuntimeException when the request cannot be built, fails, or gets a non-2xx status.
     */
    public CompletableFuture<String> callExternalApi(String url, Map<String, String> params, String method) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                Request request = buildRequest(url, params, method);
                return executeRequest(request);
            } catch (Exception e) {
                logger.log(Level.SEVERE, "构建请求或执行请求时出错", e);
                throw new RuntimeException("调用 API 时出错: " + url, e);
            }
        }, executor);
    }

    /**
     * Builds a GET request and appends every entry of {@code params} as a query parameter.
     *
     * @throws IllegalArgumentException if {@code url} is not a valid HTTP/HTTPS URL
     */
    private Request buildGetRequest(String url, Map<String, String> params) {
        // HttpUrl.parse returns null (rather than throwing) for malformed URLs, so check explicitly
        // instead of failing later with an opaque NullPointerException.
        HttpUrl parsed = HttpUrl.parse(url);
        if (parsed == null) {
            throw new IllegalArgumentException("无效的 URL: " + url);
        }
        HttpUrl.Builder httpBuilder = parsed.newBuilder();
        if (params != null && !params.isEmpty()) {
            params.forEach(httpBuilder::addQueryParameter);
        }
        return new Request.Builder().url(httpBuilder.build()).get().build();
    }

    /**
     * Builds a POST request whose body is {@code params} serialized as a JSON object.
     *
     * @throws IOException if the parameters cannot be serialized
     */
    private Request buildPostRequest(String url, Map<String, String> params) throws IOException {
        RequestBody body = RequestBody.create(
                MediaType.parse("application/json"),
                OBJECT_MAPPER.writeValueAsString(params)
        );
        return new Request.Builder().url(url).post(body).build();
    }

    /**
     * Dispatches to the GET or POST builder based on {@code method}.
     *
     * @throws IllegalArgumentException for any method other than GET/POST (including null)
     */
    private Request buildRequest(String url, Map<String, String> params, String method) throws IOException {
        if ("GET".equalsIgnoreCase(method)) {
            return buildGetRequest(url, params);
        } else if ("POST".equalsIgnoreCase(method)) {
            return buildPostRequest(url, params);
        } else {
            throw new IllegalArgumentException("不支持的方法: " + method);
        }
    }

    /**
     * Executes the request synchronously and returns the body as a string.
     * The Response is always closed by try-with-resources.
     *
     * @throws RuntimeException on a non-successful status or a missing body
     */
    private String executeRequest(Request request) throws IOException {
        try (Response response = client.newCall(request).execute()) {
            if (response.isSuccessful() && response.body() != null) {
                String responseBody = response.body().string();
                logger.info("收到响应: " + responseBody);
                return responseBody;
            } else {
                logger.warning("收到非正常响应码: " + response.code());
                throw new RuntimeException("调用 API 失败,响应码: " + response.code());
            }
        }
    }

    /** Starts one asynchronous call per request and returns the futures in the same order. */
    public List<CompletableFuture<String>> callMultipleApis(List<ApiRequest> apiRequests) {
        logger.info("正在调用多个 API...");
        return apiRequests.stream()
                .map(request -> callExternalApi(request.getUrl(), request.getParams(), request.getMethod()))
                .collect(Collectors.toList());
    }

    /**
     * Once every future has completed, logs and prints each result (or error) in list order.
     * Returns immediately; processing happens asynchronously.
     */
    public void processApiResponses(List<CompletableFuture<String>> futures) {
        processApiResponsesAsync(futures);
    }

    // Same as processApiResponses, but returns a future that completes after all results are printed.
    private CompletableFuture<Void> processApiResponsesAsync(List<CompletableFuture<String>> futures) {
        CompletableFuture<Void> allOf = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
        // Use handle, not thenAccept: allOf completes exceptionally when ANY future fails, and
        // thenAccept would then skip the callback entirely, dropping the successful results too.
        return allOf.handle((ignored, anyFailure) -> {
            futures.forEach(ApiServiceCaller::logOutcome);
            return null;
        });
    }

    // Logs/prints one (already completed) future's response or error.
    private static void logOutcome(CompletableFuture<String> future) {
        future.handle((response, throwable) -> {
            if (throwable != null) {
                logger.log(Level.SEVERE, "处理 future 出错", throwable);
                System.err.println("处理 future 出错: " + throwable.getMessage());
            } else {
                logger.info("处理响应: " + response);
                System.out.println(response);
            }
            return null;
        });
    }

    /** Demo entry point: calls three sample APIs and prints their results. */
    public static void main(String[] args) {
        ApiServiceCaller apiServiceCaller = new ApiServiceCaller();
        List<ApiRequest> apiRequests = new ArrayList<>();
        apiRequests.add(new ApiRequest("http://example.com/api1", Map.of("param1", "value1"), "GET"));
        apiRequests.add(new ApiRequest("http://example.com/api2", Map.of("key", "value"), "POST"));
        apiRequests.add(new ApiRequest("http://example.com/api3", Map.of("param3", "value3"), "GET"));

        logger.info("开始调用 API...");
        List<CompletableFuture<String>> apiCalls = apiServiceCaller.callMultipleApis(apiRequests);
        // The executor threads are daemons. Wait here, or the JVM may exit before any
        // response is processed.
        apiServiceCaller.processApiResponsesAsync(apiCalls).join();
    }

    /**
     * Unit tests for ApiServiceCaller.
     *
     * <p>NOTE(review): these tests make real network calls to http://example.com. They presumably
     * fail offline or when the endpoint does not return 2xx; consider a local mock server.
     */
    public static class ApiServiceCallerTest {

        @Test
        public void testCallExternalApi_getRequest() {
            ApiServiceCaller caller = new ApiServiceCaller();
            CompletableFuture<String> responseFuture = caller.callExternalApi("http://example.com/api1", Map.of("param", "value"), "GET");
            Assertions.assertDoesNotThrow(() -> {
                String response = responseFuture.get(10, TimeUnit.SECONDS);
                Assertions.assertNotNull(response);
            });
        }

        @Test
        public void testCallExternalApi_postRequest() {
            ApiServiceCaller caller = new ApiServiceCaller();
            CompletableFuture<String> responseFuture = caller.callExternalApi("http://example.com/api1", Map.of("key", "value"), "POST");
            Assertions.assertDoesNotThrow(() -> {
                String response = responseFuture.get(10, TimeUnit.SECONDS);
                Assertions.assertNotNull(response);
            });
        }

        @Test
        public void testCallMultipleApis() {
            ApiServiceCaller caller = new ApiServiceCaller();
            List<ApiRequest> apiRequests = new ArrayList<>();
            apiRequests.add(new ApiRequest("http://example.com/api1", Map.of("param1", "value1"), "GET"));
            apiRequests.add(new ApiRequest("http://example.com/api2", Map.of("key", "value"), "POST"));
            List<CompletableFuture<String>> responseFutures = caller.callMultipleApis(apiRequests);
            Assertions.assertEquals(2, responseFutures.size());
            responseFutures.forEach(future -> Assertions.assertDoesNotThrow(() -> {
                String response = future.get(10, TimeUnit.SECONDS);
                Assertions.assertNotNull(response);
            }));
        }
    }

    /** Immutable holder for one API request: URL, parameters and HTTP method. */
    public static class ApiRequest {
        private final String url;
        private final Map<String, String> params;
        private final String method;

        public ApiRequest(String url, Map<String, String> params, String method) {
            this.url = url;
            this.params = params;
            this.method = method;
        }

        public String getUrl() {
            return url;
        }

        public Map<String, String> getParams() {
            return params;
        }

        public String getMethod() {
            return method;
        }
    }
}

// Ensure graceful shutdown of the executor.
// NOTE(review): as written, this statement sits outside any class body, which is not valid Java.
// It uses ApiServiceCaller's private static `logger` and `executor`, so it presumably belongs
// inside that class, e.g. wrapped in a `static { ... }` initializer.
// On JVM shutdown: request an orderly shutdown, wait up to 5 seconds, then fall back to shutdownNow().
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
    try {
        logger.info("正在关闭执行器...");
        executor.shutdown();
        if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
            // Tasks did not finish in time: attempt to stop the remaining ones.
            logger.warning("执行器未在指定时间内终止。");
            executor.shutdownNow();
        }
    } catch (InterruptedException e) {
        // NOTE(review): the thread's interrupt flag is not restored here.
        logger.log(Level.SEVERE, "关闭过程中断", e);
        executor.shutdownNow();
    }
}));

二、压缩文件

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectInputStream;

import java.io.*;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import java.util.concurrent.TimeUnit;

/**
 * Downloads a list of S3 objects concurrently and packs them into one in-memory ZIP archive.
 *
 * <p>The executor is shut down at the end of {@link #getCompressedFileStream}, so each instance
 * can build only one archive.
 */
public class S3DownloadAndCompress {

    private final AmazonS3 s3Client;
    private final ExecutorService executorService;

    /**
     * Builds the S3 client with the builder's default settings and a fixed-size download pool.
     *
     * @param threadPoolSize number of threads used to download objects concurrently
     */
    public S3DownloadAndCompress(int threadPoolSize) {
        System.out.println("初始化 S3 客户端和执行器服务...");
        this.s3Client = AmazonS3ClientBuilder.standard().build();
        this.executorService = Executors.newFixedThreadPool(threadPoolSize);
    }

    /**
     * Downloads every key from {@code bucketName} in parallel and writes each one as a ZIP entry
     * named after its key. Shuts down the executor before returning.
     *
     * @param fileKeys   S3 object keys to include
     * @param bucketName source bucket
     * @return the ZIP archive bytes (the ZipOutputStream is closed before returning)
     */
    public ByteArrayOutputStream getCompressedFileStream(List<String> fileKeys, String bucketName) {
        System.out.println("开始下载和压缩过程...");
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ZipOutputStream zipOut = new ZipOutputStream(baos)) {
            List<CompletableFuture<Void>> futures = fileKeys.stream()
                    .map(fileKey -> CompletableFuture.runAsync(() -> {
                        System.out.println("开始下载和压缩文件: " + fileKey);
                        downloadAndCompressFile(s3Client, bucketName, fileKey, zipOut);
                        System.out.println("完成下载和压缩文件: " + fileKey);
                    }, executorService))
                    .collect(Collectors.toList());

            // allOf waits for every task, so no task is still writing when zipOut is closed.
            CompletableFuture<Void> allDownloads = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
            allDownloads.join();
            System.out.println("所有文件已成功下载和压缩。");
        } catch (IOException e) {
            System.err.println("下载和压缩过程中出错: " + e.getMessage());
            e.printStackTrace();
        } finally {
            System.out.println("关闭执行器服务...");
            executorService.shutdown();
            try {
                if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
                    System.out.println("执行器服务未能在60秒内终止,正在强制关闭...");
                    executorService.shutdownNow();
                }
            } catch (InterruptedException e) {
                System.out.println("等待执行器服务终止时被中断,强制关闭...");
                executorService.shutdownNow();
                // Restore the interrupt flag so callers can still observe the interruption.
                Thread.currentThread().interrupt();
            }
        }
        // A successfully closed ZipOutputStream always writes the end-of-central-directory record,
        // so this triggers only when the archive could not be written at all.
        if (baos.size() == 0) {
            System.out.println("压缩文件流为空。");
        }
        return baos;
    }

    /**
     * Writes the archive bytes to {@code targetPath}, overwriting any existing file.
     * Does nothing (besides logging) when the stream is null or empty.
     */
    public void saveCompressedFileToPath(ByteArrayOutputStream compressedStream, String targetPath) {
        if (compressedStream == null || compressedStream.size() == 0) {
            System.out.println("压缩文件流为空,无法保存。");
            return;
        }
        try (FileOutputStream fos = new FileOutputStream(targetPath)) {
            compressedStream.writeTo(fos);
            System.out.println("压缩文件已保存到: " + targetPath);
        } catch (IOException e) {
            System.err.println("保存压缩文件时出错: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Downloads one object and appends it to {@code zipOut} as an entry named {@code fileKey}.
     *
     * <p>The download happens outside the lock so several objects can be fetched in parallel.
     * Only the ZIP writes are serialized, because ZipOutputStream is not thread-safe.
     * IOExceptions are logged, and the file is then skipped.
     */
    private void downloadAndCompressFile(AmazonS3 s3Client, String bucketName, String fileKey, ZipOutputStream zipOut) {
        byte[] content;
        try (S3Object s3Object = s3Client.getObject(bucketName, fileKey);
             S3ObjectInputStream s3is = s3Object.getObjectContent()) {
            System.out.println("从桶中下载文件: " + fileKey + " 桶名称: " + bucketName);
            // Read the whole object before closing, so the SDK does not abort a partially read stream.
            content = s3is.readAllBytes();
        } catch (IOException e) {
            System.err.println("下载或压缩文件时出错: " + fileKey + " - " + e.getMessage());
            e.printStackTrace();
            return;
        }
        synchronized (zipOut) {
            try {
                zipOut.putNextEntry(new ZipEntry(fileKey));
                zipOut.write(content);
                zipOut.closeEntry();
                System.out.println("文件 " + fileKey + " 已添加到 zip 中。");
            } catch (IOException e) {
                System.err.println("下载或压缩文件时出错: " + fileKey + " - " + e.getMessage());
                e.printStackTrace();
            }
        }
    }

    /** Demo entry point: zips three sample keys and saves the archive locally. */
    public static void main(String[] args) {
        System.out.println("启动 S3DownloadAndCompress...");
        int threadPoolSize = 10; // Configure as needed
        S3DownloadAndCompress downloader = new S3DownloadAndCompress(threadPoolSize);
        List<String> fileKeys = List.of("file1.txt", "file2.txt", "file3.txt");
        String bucketName = "your-bucket-name";
        String targetPath = "compressed_files.zip";

        ByteArrayOutputStream compressedStream = downloader.getCompressedFileStream(fileKeys, bucketName);
        downloader.saveCompressedFileToPath(compressedStream, targetPath);
        System.out.println("S3DownloadAndCompress 完成。");
    }
}
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.model.GeneratePresignedUrlRequest;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectInputStream;
import com.amazonaws.services.s3.transfer.TransferManager;
import com.amazonaws.services.s3.transfer.TransferManagerBuilder;
import com.amazonaws.services.s3.transfer.Download;
import com.amazonaws.HttpMethod;

import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.Date;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import java.util.concurrent.TimeUnit;

/**
 * S3 helper with four jobs: building ZIP archives of several objects, downloading single objects,
 * generating presigned GET URLs, and downloading through them.
 *
 * <p>The download pool is shut down at the end of {@link #getCompressedFileStream}, so only one
 * archive can be built per instance.
 */
public class S3DownloadAndCompress {

    private final AmazonS3 s3Client;
    private final ExecutorService executorService;
    private final TransferManager transferManager;
    // Key that resolveFileKey substitutes when the requested object does not exist.
    private final String defaultFileName = "default_filename.txt";

    /**
     * Initializes the S3 client, a fixed-size download pool, and a TransferManager that shares the
     * same client.
     *
     * @param threadPoolSize number of threads used for concurrent ZIP downloads
     */
    public S3DownloadAndCompress(int threadPoolSize) {
        System.out.println("初始化 S3 客户端和执行器服务...");
        this.s3Client = AmazonS3ClientBuilder.standard().build();
        this.executorService = Executors.newFixedThreadPool(threadPoolSize);
        this.transferManager = TransferManagerBuilder.standard().withS3Client(s3Client).build();
        System.out.println("S3 客户端和执行器服务初始化完成。");
    }

    /**
     * Downloads all {@code fileKeys} in parallel, zips them (one entry per key), and returns the
     * archive. Shuts down the executor before returning.
     */
    public ByteArrayOutputStream getCompressedFileStream(List<String> fileKeys, String bucketName) {
        System.out.println("开始下载和压缩过程...");
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ZipOutputStream zipOut = new ZipOutputStream(baos)) {
            List<CompletableFuture<Void>> futures = fileKeys.stream()
                    .map(fileKey -> CompletableFuture.runAsync(() -> {
                        System.out.println("开始下载和压缩文件: " + fileKey);
                        downloadAndCompressFile(bucketName, fileKey, zipOut);
                        System.out.println("完成下载和压缩文件: " + fileKey);
                    }, executorService))
                    .collect(Collectors.toList());

            // allOf waits for every task, so no task is still writing when zipOut is closed.
            CompletableFuture<Void> allDownloads = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
            allDownloads.join();
            System.out.println("所有文件已成功下载和压缩。");
        } catch (IOException e) {
            System.err.println("下载和压缩过程中出错: " + e.getMessage());
            e.printStackTrace();
        } finally {
            shutdownExecutorService();
        }
        System.out.println("压缩过程完成,返回压缩文件流。");
        return baos;
    }

    /**
     * Builds a ZIP of {@code fileKeys} and writes it to {@code targetPath}. {@link #main} depends
     * on this method, which was previously missing and left the class uncompilable.
     *
     * @throws IllegalArgumentException if the resulting archive stream is empty
     */
    public void downloadAndCompressFileToPath(List<String> fileKeys, String bucketName, String targetPath) {
        ByteArrayOutputStream compressedStream = getCompressedFileStream(fileKeys, bucketName);
        saveCompressedFileToPath(compressedStream, targetPath);
    }

    /**
     * Writes the archive bytes to {@code targetPath}, overwriting any existing file.
     *
     * @throws IllegalArgumentException if the stream is null or empty
     */
    public void saveCompressedFileToPath(ByteArrayOutputStream compressedStream, String targetPath) {
        if (compressedStream == null || compressedStream.size() == 0) {
            throw new IllegalArgumentException("压缩文件流为空,无法保存。");
        }
        System.out.println("开始将压缩文件保存到路径: " + targetPath);
        try (FileOutputStream fos = new FileOutputStream(targetPath)) {
            compressedStream.writeTo(fos);
            System.out.println("压缩文件已保存到: " + targetPath);
        } catch (IOException e) {
            System.err.println("保存压缩文件时出错: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Downloads one object to {@code targetPath} through the TransferManager and blocks until it
     * finishes. If the key does not exist, {@link #resolveFileKey} substitutes the default key.
     * Errors are logged, not thrown.
     */
    public void downloadFileToPath(String bucketName, String fileKey, String targetPath) {
        System.out.println("开始从 S3 下载文件: " + fileKey + " 到路径: " + targetPath);
        try {
            String resolvedFileKey = resolveFileKey(bucketName, fileKey);
            File targetFile = new File(targetPath);
            Download download = transferManager.download(bucketName, resolvedFileKey, targetFile);
            download.waitForCompletion();
            System.out.println("文件已成功下载到: " + targetPath);
        } catch (InterruptedException e) {
            // Keep the interrupt visible to callers instead of swallowing it.
            Thread.currentThread().interrupt();
            System.err.println("下载文件时出错: " + e.getMessage());
            e.printStackTrace();
        } catch (Exception e) {
            System.err.println("下载文件时出错: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Generates a presigned GET URL valid for {@code expirationMinutes}.
     *
     * @return the URL, or null if generation failed (the error is logged)
     */
    public URL generatePresignedUrl(String bucketName, String fileKey, int expirationMinutes) {
        System.out.println("生成临时链接,文件: " + fileKey + " 有效期: " + expirationMinutes + " 分钟");
        try {
            String resolvedFileKey = resolveFileKey(bucketName, fileKey);
            // Use long arithmetic: minutes * 60 * 1000 overflows int beyond ~35,791 minutes.
            Date expiration = new Date(System.currentTimeMillis() + expirationMinutes * 60_000L);
            GeneratePresignedUrlRequest request = new GeneratePresignedUrlRequest(bucketName, resolvedFileKey)
                    .withMethod(HttpMethod.GET)
                    .withExpiration(expiration);
            URL url = s3Client.generatePresignedUrl(request);
            System.out.println("生成的临时链接: " + url.toString());
            return url;
        } catch (Exception e) {
            System.err.println("生成临时链接时出错: " + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /** Downloads the content behind a presigned URL into {@code targetPath}; errors are logged. */
    public void downloadFileFromPresignedUrl(URL presignedUrl, String targetPath) {
        System.out.println("使用临时链接下载文件到路径: " + targetPath);
        try (BufferedInputStream in = new BufferedInputStream(presignedUrl.openStream());
             FileOutputStream fileOutputStream = new FileOutputStream(targetPath)) {
            byte[] dataBuffer = new byte[8192];
            int bytesRead;
            while ((bytesRead = in.read(dataBuffer, 0, 8192)) != -1) {
                fileOutputStream.write(dataBuffer, 0, bytesRead);
            }
            System.out.println("文件已通过临时链接成功下载到: " + targetPath);
        } catch (IOException e) {
            System.err.println("通过临时链接下载文件时出错: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Opens a GET connection to a presigned URL and returns its input stream.
     * The caller is responsible for closing the stream.
     *
     * @return the stream, or null on I/O error (the error is logged)
     */
    public InputStream getFileStreamFromPresignedUrl(URL presignedUrl) {
        System.out.println("通过临时链接获取文件流: " + presignedUrl);
        try {
            HttpURLConnection connection = (HttpURLConnection) presignedUrl.openConnection();
            connection.setRequestMethod("GET");
            InputStream inputStream = connection.getInputStream();
            System.out.println("成功获取文件流。");
            return inputStream;
        } catch (IOException e) {
            System.err.println("通过临时链接获取文件流时出错: " + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /** Returns {@code fileKey} if the object exists, otherwise the default file name. */
    private String resolveFileKey(String bucketName, String fileKey) {
        System.out.println("解析文件键名: " + fileKey);
        if (s3Client.doesObjectExist(bucketName, fileKey)) {
            System.out.println("文件存在: " + fileKey);
            return fileKey;
        } else {
            System.out.println("文件不存在,使用默认文件名: " + defaultFileName);
            return defaultFileName;
        }
    }

    /**
     * Downloads one object and appends it to {@code zipOut} as an entry named {@code fileKey}.
     *
     * <p>The download happens outside the lock, so downloads really run in parallel. Only the
     * writes to the non-thread-safe ZipOutputStream are serialized. IOExceptions are logged, and
     * the file is then skipped.
     */
    private void downloadAndCompressFile(String bucketName, String fileKey, ZipOutputStream zipOut) {
        System.out.println("从 S3 下载并压缩文件: " + fileKey);
        byte[] content;
        try (S3Object s3Object = s3Client.getObject(bucketName, fileKey);
             S3ObjectInputStream s3is = s3Object.getObjectContent()) {
            System.out.println("从桶中下载文件: " + fileKey + " 桶名称: " + bucketName);
            // Read the whole object before closing, so the SDK does not abort a partially read stream.
            content = s3is.readAllBytes();
        } catch (IOException e) {
            System.err.println("下载或压缩文件时出错: " + fileKey + " - " + e.getMessage());
            e.printStackTrace();
            return;
        }
        synchronized (zipOut) {
            try {
                zipOut.putNextEntry(new ZipEntry(fileKey));
                zipOut.write(content);
                zipOut.closeEntry();
                System.out.println("文件 " + fileKey + " 已添加到 zip 中。");
            } catch (IOException e) {
                System.err.println("下载或压缩文件时出错: " + fileKey + " - " + e.getMessage());
                e.printStackTrace();
            }
        }
    }

    /**
     * Shuts the executor down gracefully. It waits up to 60 seconds, then forces shutdown.
     * If interrupted, it forces shutdown and restores the interrupt flag.
     */
    private void shutdownExecutorService() {
        System.out.println("关闭执行器服务...");
        try {
            executorService.shutdown();
            if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
                System.out.println("执行器服务未能在60秒内终止,正在强制关闭...");
                executorService.shutdownNow();
                System.out.println("已调用 shutdownNow() 强制关闭执行器服务。");
            }
        } catch (InterruptedException e) {
            System.out.println("等待执行器服务终止时被中断,强制关闭...");
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
        }
        System.out.println("执行器服务已关闭。");
    }

    /** Demo entry point exercising every operation of the class. */
    public static void main(String[] args) {
        System.out.println("启动 S3DownloadAndCompress...");
        int threadPoolSize = 10; // Configure as needed
        S3DownloadAndCompress downloader = new S3DownloadAndCompress(threadPoolSize);
        List<String> fileKeys = List.of("file1.txt", "file2.txt", "file3.txt");
        String bucketName = "your-bucket-name";
        String targetPath = "compressed_files.zip";

        try {
            // Download, compress, and save to the target path
            System.out.println("开始下载并压缩文件...");
            downloader.downloadAndCompressFileToPath(fileKeys, bucketName, targetPath);
            System.out.println("下载并压缩文件完成。");

            // Download directly to a path
            System.out.println("开始直接下载文件...");
            downloader.downloadFileToPath(bucketName, "file1.txt", "downloaded_file1.txt");
            System.out.println("直接下载文件完成。");

            // Generate a presigned URL
            System.out.println("开始生成临时链接...");
            URL presignedUrl = downloader.generatePresignedUrl(bucketName, "file2.txt", 60);
            if (presignedUrl != null) {
                System.out.println("访问临时链接: " + presignedUrl);
                // Download through the presigned URL
                System.out.println("通过临时链接下载文件...");
                downloader.downloadFileFromPresignedUrl(presignedUrl, "downloaded_from_presigned_url.txt");
                System.out.println("通过临时链接下载文件完成。");
                // Open the stream and make sure it is closed (try-with-resources tolerates null)
                System.out.println("获取文件流...");
                try (InputStream fileStream = downloader.getFileStreamFromPresignedUrl(presignedUrl)) {
                    if (fileStream != null) {
                        System.out.println("成功获取文件流。");
                    }
                } catch (IOException e) {
                    System.err.println("关闭文件流时出错: " + e.getMessage());
                    e.printStackTrace();
                }
            }
        } finally {
            // Release the TransferManager's worker threads. Pass false so the S3 client is not
            // shut down with it.
            downloader.transferManager.shutdownNow(false);
        }

        System.out.println("S3DownloadAndCompress 完成。");
    }
}

三、文件存储

1. 配置

# Per-bucket S3 connection settings: each aws.buckets.<name>.* group describes one bucket.
# <name> (bucket1, bucket2) is the key the bucket is looked up by,
# e.g. getBucketConfig("bucket1") in BucketsConfigTest.
# NOTE(review): these are placeholder credentials - never commit real access/secret keys;
# the endpoints use plain http, i.e. traffic is not encrypted.

# Bucket 1 Configuration
aws.buckets.bucket1.accessKey=accessKey1
aws.buckets.bucket1.secretKey=secretKey1
aws.buckets.bucket1.endpoint=http://endpoint1
aws.buckets.bucket1.region=us-east-1

# Bucket 2 Configuration
aws.buckets.bucket2.accessKey=accessKey2
aws.buckets.bucket2.secretKey=secretKey2
aws.buckets.bucket2.endpoint=http://endpoint2
aws.buckets.bucket2.region=us-west-1

2. 实体类

package com.example.s3config;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * Connection settings for a single S3 bucket: credentials, endpoint and region.
 *
 * <p>Instances are created by the configuration binder as the values of the
 * {@code Map<String, BucketConfig>} held by {@link BucketsConfig}. For that reason this class is
 * deliberately a plain value type rather than a {@code @Component}/{@code @ConfigurationProperties}
 * bean. With an empty prefix, those annotations would register a stray singleton bound to
 * top-level {@code accessKey}/{@code secretKey}/{@code endpoint}/{@code region} properties.
 */
public class BucketConfig {

    private String accessKey;
    private String secretKey;
    private String endpoint;
    private String region;

    // Getters and setters (required by the JavaBean-style property binder)
    public String getAccessKey() {
        return accessKey;
    }

    public void setAccessKey(String accessKey) {
        this.accessKey = accessKey;
    }

    public String getSecretKey() {
        return secretKey;
    }

    public void setSecretKey(String secretKey) {
        this.secretKey = secretKey;
    }

    public String getEndpoint() {
        return endpoint;
    }

    public void setEndpoint(String endpoint) {
        this.endpoint = endpoint;
    }

    public String getRegion() {
        return region;
    }

    public void setRegion(String region) {
        this.region = region;
    }
}

3. 配置类

package com.example.s3config;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
 * All configured S3 buckets, keyed by logical bucket name.
 *
 * <p>Bound from {@code aws.buckets.<name>.*} properties. The prefix is {@code aws} and the map
 * property is {@code buckets}. With the previous {@code prefix = "aws.buckets"}, the map property
 * {@code bucketConfigs} was only bound from {@code aws.buckets.bucket-configs.<name>.*}, so the
 * documented {@code aws.buckets.bucket1.*} keys were silently ignored.
 * {@link #getBucketConfigs()}/{@link #setBucketConfigs(Map)} remain as aliases of the same map.
 */
@Component
@ConfigurationProperties(prefix = "aws")
public class BucketsConfig {

    private static final Logger logger = LoggerFactory.getLogger(BucketsConfig.class);

    private Map<String, BucketConfig> bucketConfigs = new HashMap<>();

    /** Binding target for {@code aws.buckets.<name>.*}; same map as {@link #getBucketConfigs()}. */
    public Map<String, BucketConfig> getBuckets() {
        return bucketConfigs;
    }

    /** Binder setter for {@code aws.buckets}; delegates to {@link #setBucketConfigs(Map)}. */
    public void setBuckets(Map<String, BucketConfig> buckets) {
        setBucketConfigs(buckets);
    }

    public Map<String, BucketConfig> getBucketConfigs() {
        return bucketConfigs;
    }

    /** Replaces the bucket map; a null map is treated as empty. */
    public void setBucketConfigs(Map<String, BucketConfig> bucketConfigs) {
        this.bucketConfigs = bucketConfigs == null ? new HashMap<>() : bucketConfigs;
        // Log to confirm if configurations are loaded correctly
        logger.info("Bucket configurations loaded: {}", this.bucketConfigs.keySet());
    }

    /**
     * Returns the configuration for {@code bucketName}.
     *
     * @throws IllegalArgumentException if no bucket with that name is configured
     */
    public BucketConfig getBucketConfig(String bucketName) {
        BucketConfig bucketConfig = bucketConfigs.get(bucketName);
        if (bucketConfig == null) {
            throw new IllegalArgumentException("Invalid bucket name: " + bucketName);
        }
        return bucketConfig;
    }
}

4. 初始化类

package com.example.s3config;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Lazily creates and caches one AmazonS3 client per configured bucket.
 * The cache is safe for concurrent use.
 */
@Component
public class AmazonS3Config {

    private static final Logger logger = LoggerFactory.getLogger(AmazonS3Config.class);

    private final BucketsConfig bucketsConfig;
    // Logical bucket name -> client built from that bucket's settings.
    private final Map<String, AmazonS3> amazonS3ClientsCache = new ConcurrentHashMap<>();

    @Autowired
    public AmazonS3Config(BucketsConfig bucketsConfig) {
        this.bucketsConfig = bucketsConfig;
        logger.info("AmazonS3Config initialized with BucketsConfig");
    }

    /**
     * Returns the client for {@code bucketName}, building and caching it on first use.
     *
     * @throws IllegalArgumentException if the bucket is unknown or its configuration is incomplete
     */
    public AmazonS3 getAmazonS3Client(String bucketName) {
        AmazonS3 cached = amazonS3ClientsCache.get(bucketName);
        if (cached != null) {
            logger.debug("Returning cached AmazonS3 client for bucket: {}", bucketName);
            return cached;
        }
        // computeIfAbsent is atomic, so concurrent first calls for the same bucket build a single
        // client. (The old containsKey/get/put sequence could build duplicates.) If creation
        // throws, nothing is cached and the exception propagates unchanged.
        return amazonS3ClientsCache.computeIfAbsent(bucketName, this::createAmazonS3Client);
    }

    // Validates the bucket's settings and builds a path-style client for its endpoint/region.
    private AmazonS3 createAmazonS3Client(String bucketName) {
        BucketConfig bucketConfig = bucketsConfig.getBucketConfig(bucketName);

        // Ensure all required configurations are present
        if (bucketConfig.getAccessKey() == null || bucketConfig.getSecretKey() == null ||
                bucketConfig.getEndpoint() == null || bucketConfig.getRegion() == null) {
            throw new IllegalArgumentException("Incomplete bucket configuration for: " + bucketName);
        }

        BasicAWSCredentials awsCreds = new BasicAWSCredentials(bucketConfig.getAccessKey(), bucketConfig.getSecretKey());
        AmazonS3 amazonS3 = AmazonS3ClientBuilder.standard()
                .withCredentials(new AWSStaticCredentialsProvider(awsCreds))
                .withEndpointConfiguration(
                        new AmazonS3ClientBuilder.EndpointConfiguration(bucketConfig.getEndpoint(), bucketConfig.getRegion()))
                .withPathStyleAccessEnabled(true)
                .build();

        logger.info("AmazonS3 client created and cached for bucket: {}", bucketName);
        return amazonS3;
    }
}

5. S3 服务类（上传对象）

package com.example.s3config;

import com.amazonaws.services.s3.AmazonS3;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;

/** Thin service facade over per-bucket S3 clients supplied by {@link AmazonS3Config}. */
@Service
public class S3Service {

    private static final Logger log = LoggerFactory.getLogger(S3Service.class);

    private final AmazonS3Config clientProvider;

    /** Creates the service on top of the given client factory. */
    @Autowired
    public S3Service(AmazonS3Config amazonS3Config) {
        clientProvider = amazonS3Config;
        log.info("S3Service initialized with AmazonS3Config");
    }

    /**
     * Uploads {@code file} to {@code bucketName} under {@code key}, using that bucket's client.
     *
     * @param bucketName logical bucket name, also used as the S3 bucket name
     * @param key        object key to write
     * @param file       local file to upload
     */
    public void uploadFile(String bucketName, String key, File file) {
        clientProvider.getAmazonS3Client(bucketName).putObject(bucketName, key, file);
        log.info("File uploaded to bucket: {}, key: {}", bucketName, key);
    }

    // Other operations
}

6. 主程序

package com.example.s3config;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.CommandLineRunner;
import org.springframework.context.annotation.Bean;

/**
 * Spring Boot entry point.
 *
 * <p>{@link BucketsConfig} is already a {@code @Component} annotated with
 * {@code @ConfigurationProperties}. Component scanning picks it up, and Spring Boot's
 * configuration-properties support binds it. Listing it again in
 * {@code @EnableConfigurationProperties} registered a second bean of the same type (named
 * {@code <prefix>-<fully.qualified.Name>}). That makes by-type injection and
 * {@code @MockBean BucketsConfig} ambiguous.
 */
@SpringBootApplication
public class YourApplication {
    public static void main(String[] args) {
        SpringApplication.run(YourApplication.class, args);
    }

    /** Logs the bucket names that were bound at startup, as a quick configuration sanity check. */
    @Bean
    CommandLineRunner validateBucketsConfig(BucketsConfig bucketsConfig) {
        return args -> {
            System.out.println("Validating bucket configurations: " + bucketsConfig.getBucketConfigs().keySet());
        };
    }
}

7. 测试类

package com.example.s3config;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.TestPropertySource;

import static org.junit.jupiter.api.Assertions.*;

/** Integration tests checking that bucket settings are bound from application.properties. */
@SpringBootTest
@TestPropertySource("classpath:application.properties")
public class BucketsConfigTest {

    @Autowired
    private BucketsConfig bucketsConfig;

    /** The bucket map is populated and contains both bucket1 and bucket2. */
    @Test
    public void testBucketsConfigLoaded() {
        assertNotNull(bucketsConfig, "BucketsConfig should not be null");
        assertFalse(bucketsConfig.getBucketConfigs().isEmpty(), "Bucket configurations should not be empty");
        assertBucketPresent("bucket1", "Bucket1 should be present in the configurations");
        assertBucketPresent("bucket2", "Bucket2 should be present in the configurations");
    }

    /** Every property of bucket1 matches the values configured for it. */
    @Test
    public void testGetBucketConfig() {
        BucketConfig first = bucketsConfig.getBucketConfig("bucket1");
        assertNotNull(first, "BucketConfig for bucket1 should not be null");
        assertEquals("accessKey1", first.getAccessKey());
        assertEquals("secretKey1", first.getSecretKey());
        assertEquals("http://endpoint1", first.getEndpoint());
        assertEquals("us-east-1", first.getRegion());
    }

    /** Looking up an unknown bucket fails with a descriptive IllegalArgumentException. */
    @Test
    public void testInvalidBucket() {
        IllegalArgumentException thrown = assertThrows(IllegalArgumentException.class,
                () -> bucketsConfig.getBucketConfig("invalidBucket"));
        assertEquals("Invalid bucket name: invalidBucket", thrown.getMessage());
    }

    // Asserts that the named bucket key was bound into the configuration map.
    private void assertBucketPresent(String bucketName, String message) {
        assertTrue(bucketsConfig.getBucketConfigs().containsKey(bucketName), message);
    }
}
package com.example.s3config;

import com.amazonaws.services.s3.AmazonS3;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.TestPropertySource;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

/**
 * Tests {@link AmazonS3Config} with its {@link BucketsConfig} dependency replaced by a
 * Mockito mock registered in the Spring context through {@code @MockBean}.
 */
@SpringBootTest
@TestPropertySource("classpath:application.properties")
public class AmazonS3ConfigTest {

    @Autowired
    private AmazonS3Config amazonS3Config;

    @MockBean
    private BucketsConfig bucketsConfig;

    /**
     * A known bucket yields a non-null client, and a second lookup for the same bucket is
     * expected to return the identical (cached) instance.
     */
    @Test
    public void testGetAmazonS3Client() {
        // Only BucketsConfig is mocked; the BucketConfig it returns is a real object.
        BucketConfig fixture = bucket1Fixture();
        when(bucketsConfig.getBucketConfig("bucket1")).thenReturn(fixture);

        AmazonS3 first = amazonS3Config.getAmazonS3Client("bucket1");
        assertNotNull(first, "AmazonS3 client should not be null");

        AmazonS3 second = amazonS3Config.getAmazonS3Client("bucket1");
        assertSame(first, second, "Cached client should be the same instance");
    }

    /**
     * The IllegalArgumentException stubbed on BucketsConfig is expected to surface from
     * getAmazonS3Client with the same type and message.
     */
    @Test
    public void testGetAmazonS3ClientInvalidBucket() {
        when(bucketsConfig.getBucketConfig("invalidBucket"))
                .thenThrow(new IllegalArgumentException("Invalid bucket name: invalidBucket"));

        IllegalArgumentException thrown = assertThrows(IllegalArgumentException.class,
                () -> amazonS3Config.getAmazonS3Client("invalidBucket"));

        assertEquals("Invalid bucket name: invalidBucket", thrown.getMessage());
    }

    /** Builds the bucket1 settings that the mocked BucketsConfig hands back. */
    private static BucketConfig bucket1Fixture() {
        BucketConfig config = new BucketConfig();
        config.setAccessKey("accessKey1");
        config.setSecretKey("secretKey1");
        config.setEndpoint("http://endpoint1");
        config.setRegion("us-east-1");
        return config;
    }
}
package com.example.s3config;

import com.amazonaws.services.s3.AmazonS3;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.boot.test.context.SpringBootTest;

import java.io.File;

import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.*;

/**
 * Pure unit test for {@link S3Service}. All collaborators are Mockito mocks, so no Spring
 * application context is required.
 *
 * <p>{@link MockitoExtension} creates the {@code @Mock} fields and injects them into the
 * {@code @InjectMocks} instance before each test. Booting a full context with
 * {@code @SpringBootTest} is unnecessary here. Relying on Spring Boot's
 * MockitoTestExecutionListener (deprecated in Spring Boot 3.4) to initialize these plain
 * Mockito annotations is fragile.
 */
@ExtendWith(MockitoExtension.class)
public class S3ServiceTest {

    @Mock
    private AmazonS3Config amazonS3Config;

    @Mock
    private AmazonS3 amazonS3;

    @InjectMocks
    private S3Service s3Service;

    /** Uploading is expected to resolve the bucket's client once and delegate to putObject. */
    @Test
    public void testUploadFile() {
        String bucketName = "bucket1";
        String key = "testFile.txt";
        // The file is not created on disk; the test only checks it is handed to the mocked client.
        File file = new File("testFile.txt");

        when(amazonS3Config.getAmazonS3Client(bucketName)).thenReturn(amazonS3);

        s3Service.uploadFile(bucketName, key, file);

        verify(amazonS3Config, times(1)).getAmazonS3Client(bucketName);
        verify(amazonS3, times(1)).putObject(bucketName, key, file);
    }

    /** A client-lookup failure is expected to surface from uploadFile with its original message. */
    @Test
    public void testUploadFileWithInvalidBucket() {
        String bucketName = "invalidBucket";
        String key = "testFile.txt";
        File file = new File("testFile.txt");

        when(amazonS3Config.getAmazonS3Client(bucketName))
                .thenThrow(new IllegalArgumentException("Invalid bucket name: " + bucketName));

        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> s3Service.uploadFile(bucketName, key, file));
        assertEquals("Invalid bucket name: " + bucketName, exception.getMessage());
    }
}

8. 依赖

确保在 pom.xml 中添加以下依赖（未写版本号的依赖由 Spring Boot 的依赖管理统一指定版本，因此项目需继承 spring-boot-starter-parent 或导入 Spring Boot BOM）：

<!-- AWS SDK -->
<dependency>
    <groupId>com.amazonaws</groupId>
    <artifactId>aws-java-sdk-s3</artifactId>
    <version>1.12.100</version>
</dependency>

<!-- Spring Boot Starter (version supplied by Spring Boot dependency management) -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter</artifactId>
</dependency>

<!-- Spring Boot Configuration Processor -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-configuration-processor</artifactId>
    <optional>true</optional>
</dependency>

<!-- Testing: spring-boot-starter-test already brings JUnit Jupiter, mockito-core and
     mockito-junit-jupiter, with versions managed by the Spring Boot BOM. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>

<!-- Mockito: no explicit version, so it stays aligned with the Boot-managed
     mockito-junit-jupiter. Pinning an older mockito-core (previously 3.9.0) overrides
     the managed version and can mismatch it at runtime. This entry is redundant with
     spring-boot-starter-test and may be removed entirely. -->
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-core</artifactId>
    <scope>test</scope>
</dependency>