使用 HTTP 请求调用本地部署的 Stable Diffusion WebUI 接口

发布于:2025-03-25 ⋅ 阅读:(48) ⋅ 点赞:(0)

stable-diffusion-webui 项目地址:https://github.com/AUTOMATIC1111/stable-diffusion-webui
具体部署教程可以在 B 站搜索,也可以直接使用整合包。注意:启动 WebUI 时需要加上 `--api` 参数,才会开放下文用到的 `/sdapi/v1/*` 接口。
下面直接编写工具类(依赖 OkHttp 和 Jackson):

public class StableDiffusionUtil {
    private static final String BASE_URL = "http://127.0.0.1:7860";
    private static final OkHttpClient CLIENT = new OkHttpClient();
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    /**
     * 获取可用模型列表
     *
     * @return 模型列表的 JSON 字符串
     * @throws IOException 如果请求失败
     */
    public static List<ModelInfo> listModels() throws IOException {
        Request request = new Request.Builder()
                .url(BASE_URL + "/sdapi/v1/sd-models")
                .get()
                .build();

        try (Response response = CLIENT.newCall(request).execute()) {
            if (response.isSuccessful()) {
                String jsonResponse = response.body().string();
                // 将 JSON 字符串解析为 ModelInfo 对象列表
                return OBJECT_MAPPER.readValue(jsonResponse, new TypeReference<List<ModelInfo>>() {});
            } else {
                throw new IOException("请求失败: " + response.code());
            }
        }
    }

    /**
     * 生成图片并保存
     *
     * @param prompt          生成图片的描述
     * @param negativePrompt  负面描述
     * @param steps           生成步骤
     * @param cfgScale        CFG 参数
     * @param width          图片宽度
     * @param height         图片高度
     * @param samplerIndex   采样器
     * @param modelCheckpoint 模型名称
     * @param outputFilePath  保存图片的文件路径
     * @throws IOException 如果请求失败或保存图片失败
     */
    public static byte[] generateImage(String prompt, String negativePrompt, int steps, int cfgScale,
                                       int width, int height, String samplerIndex, String modelCheckpoint) throws IOException {
        // 请求参数
        String json = "{"
                + "\"prompt\": \"" + prompt + "\","
                + "\"negative_prompt\": \"" + negativePrompt + "\","
                + "\"steps\": " + steps + ","
                + "\"cfg_scale\": " + cfgScale + ","
                + "\"width\": " + width + ","
                + "\"height\": " + height + ","
                + "\"sampler_index\": \"" + samplerIndex + "\","
                + "\"sd_model_checkpoint\": \"" + modelCheckpoint + "\""
                + "}";

        RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8"));
        Request request = new Request.Builder()
                .url(BASE_URL + "/sdapi/v1/txt2img")
                .post(body)
                .build();

        try (Response response = CLIENT.newCall(request).execute()) {
            if (response.isSuccessful()) {
                String jsonResponse = response.body().string();
                String imageBase64 = jsonResponse.split("\"images\":\\[\"")[1].split("\"")[0];

                // 解码 Base64 图片
                byte[] imageBytes = Base64.getDecoder().decode(imageBase64);
                ByteArrayInputStream bis = new ByteArrayInputStream(imageBytes);
                BufferedImage image = ImageIO.read(bis);

                // 将图片转换为二进制数组
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                ImageIO.write(image, "png", baos);
                return baos.toByteArray();
            } else {
                throw new IOException("请求失败: " + response.code());
            }
        }
    }

    /**
     * 切换模型
     *
     * @param modelCheckpoint 模型名称
     * @throws IOException 如果请求失败
     */
    public static void switchModel(String modelCheckpoint) throws IOException {
        // 更新模型配置
        String json = "{"
                + "\"sd_model_checkpoint\": \"" + modelCheckpoint + "\""
                + "}";

        RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8"));
        Request request = new Request.Builder()
                .url(BASE_URL + "/sdapi/v1/options")
                .post(body)
                .build();

        try (Response response = CLIENT.newCall(request).execute()) {
            if (response.isSuccessful()) {
                System.out.println("模型已切换到 " + modelCheckpoint);
            } else {
                throw new IOException("请求失败: " + response.code());
            }
        }
    }




}

调用工具类的测试示例:

@Autowired
    private SysOssService ossService;

    /**
     * Manual integration test: generates an image through the local WebUI via
     * StableDiffusionUtil and uploads the resulting PNG bytes through ossService.
     * Requires a stable-diffusion-webui instance reachable at StableDiffusionUtil's BASE_URL.
     */
    @Test
    void test8() throws Exception {
        String prompt = "A futuristic city with flying cars, highly detailed, 4k";
        String negativePrompt = "blurry, low quality, distorted, text";
        int steps = 30;
        int cfgScale = 7;
        int width = 512;
        int height = 512;
        String samplerIndex = "Euler a";
        String modelInfo = "anything-v5-PrtRE.safetensors [7f96a1a9ca]";

        // Generate the image; generateImage returns PNG-encoded bytes
        byte[] imageBytes = StableDiffusionUtil.generateImage(
                prompt, negativePrompt, steps, cfgScale, width, height, samplerIndex, modelInfo);

        String fileName = "generated_image.png";
        String contentType = "image/png";
        SysOssDTO ossDTO = ossService.upload(imageBytes, fileName, contentType);

        // Print the upload result
        System.out.println("文件上传成功,访问路径: " + ossDTO.getUrl());
    }