123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- package com.webchat.aigc.llm;
- import com.webchat.aigc.service.BotService;
- import com.webchat.common.constants.ConnectConstants;
- import com.webchat.common.enums.AiFunctionEnum;
- import com.webchat.common.enums.PromptTemplateEnum;
- import com.webchat.common.helper.SseEmitterHelper;
- import com.webchat.common.util.JsonUtil;
- import com.webchat.domain.dto.bot.BotDTO;
- import com.webchat.domain.vo.llm.FunctionCallResponse;
- import com.webchat.domain.vo.request.mess.ChatMessageRequestVO;
- import lombok.extern.slf4j.Slf4j;
- import org.apache.commons.collections.CollectionUtils;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.stereotype.Service;
- import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
- import java.util.stream.Collectors;
- @Slf4j
- @Service
- public class AiBotChatService {
- @Autowired
- private AiFunctionCallService aiFunctionCallService;
- @Autowired
- private AiGenImageService aiGenImageService;
- @Autowired
- private AiBotPluginService aiBotPluginService;
- @Autowired
- private GPTChatService gptChatService;
- @Autowired
- private BotService botService;
- @Autowired
- private AiBotQAService aiBotQAService;
- private String getSSEBizCode() {
- return ConnectConstants.ConnectBiz.getBizCode(ConnectConstants.ClientEnum.PC,
- ConnectConstants.AppEnum.WEB,
- ConnectConstants.BizEnum.CHAT);
- }
- /**
- * 我的ai助手对话功能实现
- *
- * @param messageJson
- */
- public void chat(String messageJson) {
- ChatMessageRequestVO chatMessage = JsonUtil.fromJson(messageJson, ChatMessageRequestVO.class);
- // 消息发送人
- String senderId = chatMessage.getSenderId();
- // 用户输入消息
- String message = chatMessage.getMessage();
- String bizCode = getSSEBizCode();
- SseEmitter sseEmitter = SseEmitterHelper.get(bizCode, senderId);
- if (sseEmitter == null) {
- // 链接对象为空,说明当前用户sse链接不在当前节点,直接return,由集群其他节点处理任务
- return;
- }
- /************** 当前用户的 sse 链接在当前阶段,由当前节点完成AiBot对话服务处理 **************/
- /**************************************** 意图识别 ************************************/
- SseEmitterHelper.send(this.getSSEBizCode(), senderId, "意图识别中...");
- FunctionCallResponse functionCallResponse = this.getFunction(senderId, message);
- if (functionCallResponse == null) {
- SseEmitterHelper.send(this.getSSEBizCode(), senderId, "意图识别失败,请重试~");
- return;
- }
- String function = functionCallResponse.getFunction();
- AiFunctionEnum aiFunctionEnum = AiFunctionEnum.getFunction(function);
- String funcName = "未识别到意图";
- if (aiFunctionEnum == null) {
- // 查询平台查询意图信息
- BotDTO botDTO = botService.getBotPluginFromCache(function);
- if (botDTO == null) {
- SseEmitterHelper.send(this.getSSEBizCode(), senderId, "意图识别失败,请重试~");
- return;
- }
- funcName = botDTO.getName();
- } else {
- funcName = aiFunctionEnum.getFunctionName();
- }
- String functionInfo = "意图识别:"+ funcName;
- SseEmitterHelper.send(this.getSSEBizCode(), senderId, functionInfo);
- String aiInput = functionCallResponse.getPrompt();
- /****************************** 意图处理 **********************************/
- if (AiFunctionEnum.IMAGE.name().equals(function)) {
- /**
- * 通用文生图
- */
- aiGenImageService.doGenerate(aiInput, this.getSSEBizCode(), senderId);
- } else if (AiFunctionEnum.CHAT.name().equals(function)) {
- /**
- * 通用对话
- */
- aiBotQAService.chat(sseEmitter, chatMessage);
- } else {
- /**
- * 插件类
- */
- aiBotPluginService.doChat(chatMessage, bizCode, function);
- }
- }
- private FunctionCallResponse getFunction(String senderId, String message) {
- FunctionCallResponse functionCallResponse = null;
- try {
- List<String> pluginList = new ArrayList<>();
- List<BotDTO> botDTOList = botService.getPublishBotDetailFromCache();
- if (CollectionUtils.isNotEmpty(botDTOList)) {
- pluginList = botDTOList.stream().map(bot ->
- bot.getCode().concat(":").concat(bot.getDescription()))
- .collect(Collectors.toList());
- } else {
- pluginList.add("");
- }
- Map<String, Object> vars = new HashMap<>();
- vars.put("input", message);
- vars.put("pluginFuncList", pluginList);
- functionCallResponse =
- aiFunctionCallService.getFunction(vars, PromptTemplateEnum.AIBOT_FC);
- log.info("意图识别结果 =====> input: {}, response:{}",
- message, JsonUtil.toJsonString(functionCallResponse));
- } catch (Exception e) {
- log.error("意图识别异常 =====> input: {}",message, e);
- SseEmitterHelper.send(this.getSSEBizCode(), senderId, "意图识别失败,稍后重试~");
- }
- return functionCallResponse;
- }
- }
|