百度文心一言（ERNIE Bot）Java 流式输出 Demo：Spring Boot + SSE

发布于:2024-05-15 ⋅ 阅读:(119) ⋅ 点赞:(0)

参考：GitHub 项目 mmciel/wenxin-api-java —— 百度文心一言 Java 库，支持问答和对话，支持流式输出和同步输出，提供 Spring Boot 调用样例和扩展能力。

1、依赖

<!-- Dependency declared by this article. -->
<!-- NOTE(review): the code below also uses OkHttp's SSE client (EventSource / EventSources),
     Hutool (JSONUtil), Jackson (ObjectMapper), Lombok (@Slf4j, @SneakyThrows) and Spring MVC
     (SseEmitter); confirm these are on the classpath, either transitively or declared separately. -->
<dependency>
<groupId>com.baidu.aip</groupId>
<artifactId>java-sdk</artifactId>
<version>4.16.18</version>
</dependency>

2、配置 API Key 和 Secret Key

3、主要使用的接口

4、返回的 JSON 格式

3、WenxinEventSourceListener  事件监听器

与其他接口不同，这里需要用 CompletionsResponse.Data 把返回内容封装一下，否则前端页面还得兼容非 JSON 格式的数据。

@Slf4j
public class WenxinEventSourceListener extends EventSourceListener {

    private long tokens;

    private SseEmitter sseEmitter;

    public WenxinEventSourceListener(SseEmitter sseEmitter) {
        this.sseEmitter = sseEmitter;
    }

    @Override
    public void onOpen(EventSource eventSource, Response response) {
        log.info("建立sse连接...");
    }

    @SneakyThrows
    @Override
    @JsonIgnoreProperties(ignoreUnknown = true)
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        ChatResponse bean = JSONUtil.parseObj(data).toBean(ChatResponse.class);
        log.info("返回数据:{}", data);
        if (bean.getIs_end()) {
            log.info("返回数据结束了");
            sseEmitter.send(SseEmitter.event()
                    .id("[TOKENS]")
                    .data("<br/><br/>tokens:" + tokens())
                    .reconnectTime(3000));
            sseEmitter.send(SseEmitter.event()
                    .id("[DONE]")
                    .data("[DONE]")
                    .reconnectTime(3000));
            // 传输完成后自动关闭sse
            sseEmitter.complete();
            return;
        }
        log.info("OpenAI返回数据:{}", data);
        tokens += 1;
        if (data.equals("[DONE]")) {
            log.info("OpenAI返回数据结束了");
            sseEmitter.send(SseEmitter.event()
                    .id("[TOKENS]")
                    .data("<br/><br/>tokens:" + tokens())
                    .reconnectTime(3000));
            sseEmitter.send(SseEmitter.event()
                    .id("[DONE]")
                    .data("[DONE]")
                    .reconnectTime(3000));
            // 传输完成后自动关闭sse
            sseEmitter.complete();
            return;
        }

        CompletionsResponse completionResponse = new CompletionsResponse();
        CompletionsResponse.Data dataResult = new CompletionsResponse.Data();
        dataResult.setText(bean.getResult());

        completionResponse.setData(dataResult);
        try {
            sseEmitter.send(SseEmitter.event()
                    .id(bean.getId())
                    .data(completionResponse.getData())
                    .reconnectTime(3000));
        } catch (Exception e) {
            log.error("sse信息推送失败!");
            eventSource.cancel();
            e.printStackTrace();
        }
    }

    @Override
    public void onClosed(EventSource eventSource) {
        log.info("关闭sse连接...");
    }

    @SneakyThrows
    @Override
    public void onFailure(EventSource eventSource, Throwable t, Response response) {
        if(Objects.isNull(response)){
            log.error("sse连接异常:{}", t);
            eventSource.cancel();
            return;
        }
        ResponseBody body = response.body();
        if (Objects.nonNull(body)) {
            // 错误处理 {"error_code":110,"error_msg":"Access token invalid or no longer valid"},异常:{}
            log.error("sse连接异常data:{},异常:{}", body.string(), t);
        } else {
            log.error("sse连接异常data:{},异常:{}", response, t);
        }
        eventSource.cancel();
    }

    /**
     * tokens
     * @return
     */
    public long tokens() {
        return tokens;
    }
}

4、WenXinClient：流式调用主要看 streamChat 方法。之前从千帆上找到的流式示例返回的 type 是 JSON，所以之前自己手写的 demo 总是报异常。

 /**
  * Sends {@code chatBody} to the model endpoint as a streaming request. The streamed events are
  * delivered asynchronously to {@code eventSourceListener}.
  *
  * @param chatBody request body; its {@code stream} flag is forced to {@code true}
  * @param eventSourceListener receives the streamed events; must not be null
  * @param modelE model whose API host (plus access token) forms the request URL
  * @throws WenXinException if an argument is null or the request cannot be built (for example the
  *     access-token refresh or JSON serialization fails)
  */
 public void streamChat(ChatBody chatBody, EventSourceListener eventSourceListener, ModelE modelE) {
     if (Objects.isNull(eventSourceListener)) {
         throw new WenXinException("参数异常:EventSourceListener不能为空");
     }
     if (Objects.isNull(chatBody)) {
         throw new WenXinException("参数异常:ChatBody不能为空");
     }
     chatBody.setStream(true);
     Request request;
     try {
         String url = assembleUrl(modelE);
         String json = new ObjectMapper().writeValueAsString(chatBody);
         request = new Request.Builder().url(url)
                 .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), json))
                 .build();
     } catch (WenXinException e) {
         throw e;
     } catch (Exception e) {
         log.error("请求参数解析异常:", e);
         // Surface the failure. If it were only logged, no listener callback would ever fire and
         // the caller's SSE connection would wait indefinitely.
         throw new WenXinException("请求参数解析异常:" + e.getMessage());
     }
     EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
     factory.newEventSource(request, eventSourceListener);
 }

/**
 * Builds the request URL for {@code modelE}: its API host followed by an {@code access_token}
 * query parameter. The token is fetched on every call and also stored in the {@code accessToken}
 * field.
 */
private String assembleUrl(ModelE modelE) {
    // NOTE(review): refreshAccessToken() runs per request; if it calls Baidu's OAuth endpoint
    // each time, consider caching the token until it expires.
    final String token = WenXinConfig.refreshAccessToken();
    accessToken = token;
    return new StringBuilder()
            .append(modelE.getApiHost())
            .append("?access_token=")
            .append(token)
            .toString();
}

5、定义 SSE 的接口和实现类

/**
 * Manages per-client server-sent-event (SSE) connections over which chat answers are streamed.
 */
public interface SseService {

    /**
     * Creates and registers an SSE connection for a client.
     *
     * @param uid client identifier
     * @return the emitter backing the new connection
     */
    SseEmitter createSse(String uid);

    /**
     * Closes a client's SSE connection.
     *
     * @param uid client identifier
     */
    void closeSse(String uid);

    /**
     * Submits a chat message from the client; the answer is streamed over the client's SSE
     * connection rather than returned here.
     *
     * @param uid client identifier whose SSE connection receives the answer
     * @param chatRequest the message to send
     * @return an immediate, non-streamed response
     */
    ChatResponse sseChat(String uid, ChatRequest chatRequest);
}
public class WenXinSseServiceImpl implements SseService {
    @Value("${chat.accessKeyId}")
    private String accessKeyId;
    @Value("${chat.accessKeySecret}")
    private String accessKeySecret;
    @Value("${chat.agentKey}")
    private String agentKey;
    @Value("${chat.appId}")
    private String appId;

    @Autowired
    WenXinClient wenXinClient;
    @Override
    public SseEmitter createSse(String uid) {
        //默认30秒超时,设置为0L则永不超时
        SseEmitter sseEmitter = new SseEmitter(0l);
        //完成后回调
        sseEmitter.onCompletion(() -> {
            log.info("[{}]结束连接...................", uid);
            LocalCache.CACHE.remove(uid);
        });
        //超时回调
        sseEmitter.onTimeout(() -> {
            log.info("[{}]连接超时...................", uid);
        });
        //异常回调
        sseEmitter.onError(
                throwable -> {
                    try {
                        log.info("[{}]连接异常,{}", uid, throwable.toString());
                        sseEmitter.send(SseEmitter.event()
                                .id(uid)
                                .name("发生异常!")
                                .data(Message.builder().content("发生异常请重试!").build())
                                .reconnectTime(3000));
                        LocalCache.CACHE.put(uid, sseEmitter);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
        );
        try {
            sseEmitter.send(SseEmitter.event().reconnectTime(5000));
        } catch (IOException e) {
            e.printStackTrace();
        }
        LocalCache.CACHE.put(uid, sseEmitter);
        log.info("[{}]创建sse连接成功!", uid);
        return sseEmitter;
    }

    @Override
    public void closeSse(String uid) {
        SseEmitter sse = (SseEmitter) LocalCache.CACHE.get(uid);
        if (sse != null) {
            sse.complete();
            //移除
            LocalCache.CACHE.remove(uid);
        }
    }

    @Override
    public ChatResponse sseChat(String uid, ChatRequest chatRequest) {

        if (StringUtils.isBlank(chatRequest.getMsg())) {
            log.error("参数异常,msg为null", uid);
            throw new BaseException("参数异常,msg不能为空~");
        }

        SseEmitter sseEmitter = (SseEmitter) LocalCache.CACHE.get(uid);

        if (sseEmitter == null) {
            log.info("聊天消息推送失败uid:[{}],没有创建连接,请重试。", uid);
            throw new BaseException("聊天消息推送失败uid:[{}],没有创建连接,请重试。~");
        }

        WenxinEventSourceListener openAIEventSourceListener = new WenxinEventSourceListener(sseEmitter);

        List<MessageItem> messages = new ArrayList<>();
        messages.add(MessageItem.builder().role(MessageItem.Role.USER).content(chatRequest.getMsg()).build());
        wenXinClient.streamChat(messages, openAIEventSourceListener, ModelE.ERNIE_Bot);


        LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);

        ChatResponse response = new ChatResponse();
        response.setQuestionTokens(1);

        return response;
    }
}

6、主要的 Controller 接口

/**
     * Opens the server-sent-events stream for the caller identified by the request headers.
     *
     * @param headers all request headers; the caller's uid is resolved from them via getUid
     * @return the emitter over which the chat answer will be streamed
     */
    @CrossOrigin
    @GetMapping("/createSse")
    public SseEmitter createConnect(@RequestHeader Map<String, String> headers) {
        return sseService.createSse(getUid(headers));
    }

    /**
     * Accepts a chat question for the calling client and hands it to the SSE service.
     *
     * @param chatRequest the question payload (JSON request body)
     * @param headers request headers used to resolve the caller's uid
     * @param response not used by this method
     * @return the service's immediate, non-streamed response
     */
    @CrossOrigin
    @PostMapping("/chat")
    @ResponseBody
    public ChatResponse sseChat(@RequestBody ChatRequest chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) {
        final String callerId = getUid(headers);
        return sseService.sseChat(callerId, chatRequest);
    }

    /**
     * Asks the SSE service to close the caller's connection.
     *
     * @param headers request headers used to resolve the caller's uid
     */
    @CrossOrigin
    @GetMapping("/closeSse")
    public void closeConnect(@RequestHeader Map<String, String> headers) {
        final String callerId = getUid(headers);
        sseService.closeSse(callerId);
    }

7、主要的页面代码

<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>智能问答</title>
  <link rel="stylesheet" href="styles.css"> <!-- external stylesheet -->
  <script src="HZRecorder.js"></script>
  <script src="https://cdn.bootcdn.net/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
  <!-- expected to define the global marked() function used to render answers -->
  <script src="js/markdown.min.js"></script>
  <!-- EventSourcePolyfill: unlike the native EventSource it can send the custom uid header -->
  <script src="js/eventsource.min.js"></script>
  <script>

      /**
       * Renders the accumulated markdown `text` into the answer element whose id is `uuid_str`.
       * Does nothing if that element is not in the page yet.
       */
      function setText(text, uuid_str) {
        let content = document.getElementById(uuid_str);
        // The first SSE chunk can arrive before the answer element has been created (e.g. when it
        // is created in an asynchronous callback); the accumulated text is rendered on a later call.
        if (!content) {
          return;
        }
        // NOTE(review): marked() output is assigned as raw HTML and marked does not sanitize by
        // default, so HTML in the model's answer is rendered as-is; consider a sanitizer.
        content.innerHTML = marked(text);
      }

      /**
       * Returns a random RFC 4122 version-4 UUID string (lowercase hex, 8-4-4-4-12).
       * Uses crypto.randomUUID() where available (secure contexts) and otherwise falls back to a
       * Math.random()-based generator, which is fine for ids but not for secrets.
       */
      function uuid() {
        if (typeof crypto !== "undefined" && typeof crypto.randomUUID === "function") {
          return crypto.randomUUID();
        }
        var hexDigits = "0123456789abcdef";
        var s = [];
        for (var i = 0; i < 36; i++) {
          s[i] = hexDigits.charAt(Math.floor(Math.random() * 16));
        }
        s[14] = "4"; // version nibble 0100 (version 4)
        // Variant bits 10xx: keep the two low random bits of this digit and force the high bits
        // to 10. parseInt is required because s[19] is a hex character; "a"-"f" would otherwise
        // coerce to NaN (and so to 0) under `&`.
        s[19] = hexDigits.charAt((parseInt(s[19], 16) & 0x3) | 0x8);
        s[8] = s[13] = s[18] = s[23] = "-";
        return s.join("");
      }



      /**
       * Page bootstrap: restores (or creates) this browser's uid, wires the send button and the
       * Enter key, and defines sseOneTurn, which runs one question/answer round trip.
       */
      window.onload = function () {
        let messageElement = document.getElementById("messageInput");
        let chat = document.getElementById("chat-messages");
        let uid = window.localStorage.getItem("uid");
        if (uid == null || uid == "" || uid == "null") {
          uid = uuid();
        }
        // Persist the uid so the same server-side cache key is used after a reload.
        window.localStorage.setItem("uid", uid);

        // Send button: an empty message is rejected with an alert.
        document.getElementById('sendTextButton').addEventListener('click', function () {
          try {
            if (!submitInput()) {
              alert('请输入文字消息!');
            }
          } catch (error) {
            alert('发送消息时发生错误: ' + error.message);
          }
        });

        // Enter key: sends silently (no alert when empty). isComposing skips the Enter that only
        // confirms an IME (e.g. pinyin) composition.
        messageElement.addEventListener("keydown", function (event) {
          if (event.key === "Enter" && !event.isComposing) {
            submitInput();
          }
        });

        // Reads and clears the input box, then starts a round trip; returns false if it was empty.
        function submitInput() {
          const userInput = messageElement.value.trim();
          if (!userInput) {
            return false;
          }
          // Clear the <input> element itself; assigning .value on the trimmed string is a no-op.
          messageElement.value = '';
          sseOneTurn(userInput);
          return true;
        }

        // One question/answer round trip over a dedicated SSE connection.
        function sseOneTurn(InputText) {
          // Per-turn state lives here rather than in shared outer variables, so overlapping turns
          // cannot write into each other's answer or close each other's connection.
          const answerId = uuid();
          let text = "";
          let questionSent = false;
          // Create the question and answer elements up front so the answer element exists before
          // the first SSE chunk arrives.
          const questionRow = appendTurn(InputText, answerId);

          // Open the SSE connection; the polyfill lets us send the uid header.
          const eventSource = new EventSourcePolyfill("/createSse", {
            headers: {
              uid: uid,
            },
          });

          eventSource.onopen = () => {
            console.log("开始输出后端返回值");
            // POST the question only after /createSse has answered, i.e. after the server has
            // registered this uid's emitter; sending both requests at once lets /chat win the race
            // and fail. The flag stops a polyfill reconnect (which fires onopen again) from
            // re-asking the question.
            if (!questionSent) {
              questionSent = true;
              postQuestion(InputText, questionRow, eventSource);
            }
          };
          eventSource.onmessage = (event) => {
            // Token summary sent by the server after the answer: render it, then reset the buffer.
            if (event.lastEventId == "[TOKENS]") {
              text = text + event.data;
              setText(text, answerId);
              text = "";
              return;
            }
            // End of stream: close this turn's connection.
            if (event.data == "[DONE]") {
              text = "";
              eventSource.close();
              return;
            }
            // Regular chunk: a JSON object whose text field holds the next piece of the answer.
            let json_data = JSON.parse(event.data);
            console.log(json_data);
            if (json_data.text == null || json_data.text == "null") {
              return;
            }
            text = text + json_data.text;
            setText(text, answerId);
          };
          eventSource.onerror = (event) => {
            console.log("onerror", event);
            alert("服务异常请重试并联系开发者!");
            // readyState lives on the EventSource object, not on the error event.
            if (eventSource.readyState === EventSource.CLOSED) {
              console.log("connection is closed");
            } else {
              console.log("Error occured", event);
            }
            eventSource.close();
          };
        }

        // Appends the question (as plain text, never parsed as HTML) and an empty answer
        // container, and returns the question element so the token count can be added later.
        // Plain divs are used because <tr>/<td> tags are discarded by the HTML parser outside a
        // table.
        function appendTurn(questionText, answerId) {
          const questionRow = document.createElement("div");
          questionRow.textContent = questionText;
          chat.appendChild(questionRow);

          const answerRow = document.createElement("div");
          const answer = document.createElement("article");
          answer.id = answerId;
          answer.className = "markdown-body";
          answerRow.appendChild(answer);
          chat.appendChild(answerRow);
          return questionRow;
        }

        // POSTs the question; the answer itself arrives over the SSE connection.
        function postQuestion(InputText, questionRow, eventSource) {
          $.ajax({
            type: "post",
            url: "/chat",
            data: JSON.stringify({
              msg: InputText,
            }),
            contentType: "application/json;charset=UTF-8",
            dataType: "json",
            headers: {
              uid: uid,
            },
            success: function (result) {
              // Show the question's token count below it.
              questionRow.appendChild(document.createElement("br"));
              questionRow.appendChild(document.createElement("br"));
              questionRow.appendChild(document.createTextNode(" tokens:" + result.question_tokens));
            },
            error: function () {
              console.info("发送问题失败!");
              // Nothing will be streamed for this question, so release the SSE connection.
              eventSource.close();
            },
          });
        }
      };
    </script>
  </head>
<body>

<div class="chat-container">
  <div class="chat-header">
    <h1>智能问答</h1>
  </div>
  <div class="chat-messages" id="chat-messages">
    <!-- Chat messages are appended here by the inline script. -->
  </div>
  <!-- onsubmit returns false so pressing Enter in the input (implicit form submission) does not
       reload the page; the inline script handles Enter and the send button itself. -->
  <form class="message-form" onsubmit="return false;">
    <input type="text" id="messageInput" placeholder="输入消息..." autocomplete="off">
    <button type="button" id="sendTextButton">发送文字</button>
    <!-- NOTE(review): the inline script attaches no handler to this button; presumably
         HZRecorder.js or other code wires up recording/upload — confirm. -->
    <button type="button" id="recordAndUploadButton">按住录音</button>
    <progress id="uploadProgress" value="0" max="100" style="display:none;"></progress>
  </form>
</div>

</body>

</html>

最后的呈现效果如下:


网站公告

今日签到

点亮在社区的每一天
去签到