@@ -11,6 +11,7 @@ import java.util.List; | |||
@NoArgsConstructor | |||
public class DmLandingLlmVo { | |||
private String category; | |||
private Boolean stream; | |||
private List<DmLlm> message = new ArrayList<>(); | |||
@Data | |||
@NoArgsConstructor | |||
@@ -0,0 +1,38 @@ | |||
package com.xueyi.nlt.api.nlt.feign; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.xueyi.nlt.api.nlt.domain.vo.DmLandingLlmUploadVo; | |||
import com.xueyi.nlt.api.nlt.domain.vo.DmLandingLlmVo; | |||
import com.xueyi.nlt.api.nlt.feign.factory.RemoteLandingLlmFallbackFactory; | |||
import feign.Response; | |||
import org.springframework.cloud.openfeign.FeignClient; | |||
import org.springframework.http.MediaType; | |||
import org.springframework.web.bind.annotation.*; | |||
import org.springframework.web.multipart.MultipartFile; | |||
import java.util.List; | |||
/** | |||
* 问答服务 | |||
* @Param man_id 机器人id | |||
* @Param question 问题 | |||
* @Param tenant_id 租户id | |||
* @author yrx | |||
*/ | |||
@FeignClient(url = "${notification.airport-llm.url}",name = "airport-llm", fallbackFactory = RemoteLandingLlmFallbackFactory.class) | |||
public interface RemoteAirportLagiLlmService { | |||
@PostMapping("/search/questionAnswer") | |||
@ResponseBody | |||
JSONObject query(@RequestBody DmLandingLlmVo vo); | |||
@PostMapping(value = "/search/questionAnswer",produces = MediaType.TEXT_EVENT_STREAM_VALUE) | |||
@ResponseBody | |||
Response queryStream(@RequestBody JSONObject vo); | |||
@PostMapping(value = "/search/questionAnswer",produces = MediaType.TEXT_EVENT_STREAM_VALUE) | |||
@ResponseBody | |||
String queryStreamStr(@RequestBody JSONObject vo); | |||
} |
@@ -104,6 +104,7 @@ public class ChatServerHandler extends SimpleChannelInboundHandler<TextWebSocket | |||
String devId = jsonObject.getString("devId"); | |||
Boolean llm = jsonObject.containsKey("llm")?jsonObject.getBoolean("llm"): true; | |||
Boolean integrityDetection = jsonObject.containsKey("integrityDetection")?jsonObject.getBoolean("integrityDetection"): true; | |||
String llmServer = jsonObject.containsKey("llmServer")?jsonObject.getString("llmServer"): "spark"; | |||
Long operatorId = jsonObject.getLong("operatorId"); | |||
// 获取到发送人的用户id | |||
String msg = jsonObject.getString("msg"); | |||
@@ -185,15 +186,34 @@ public class ChatServerHandler extends SimpleChannelInboundHandler<TextWebSocket | |||
} | |||
} | |||
JSONObject jo = new JSONObject(); | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",""); | |||
jo.put("status",0); | |||
jo.put("tts","请稍等一下,我要查询一下功能。"); | |||
channel.writeAndFlush(new TextWebSocketFrame(jo.toJSONString())); | |||
INSTANCE.logService.record(jo,msg,enterpriseName,"大模型"); | |||
sendMsg(devId, msg); | |||
return; | |||
switch (llmServer) { | |||
case "airport": | |||
// 调用机场大模型 | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",""); | |||
jo.put("status",0); | |||
jo.put("tts","请稍等一下,我要查询一下功能。"); | |||
channel.writeAndFlush(new TextWebSocketFrame(jo.toJSONString())); | |||
INSTANCE.logService.record(jo,msg,enterpriseName,"大模型"); | |||
INSTANCE.freeChatTemplate.handleStream(devId, msg.toString(),true); | |||
return; | |||
case "spark": | |||
default: | |||
// 调用星火大模型 | |||
jo = new JSONObject(); | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",""); | |||
jo.put("status",0); | |||
jo.put("tts","请稍等一下,我要查询一下功能。"); | |||
channel.writeAndFlush(new TextWebSocketFrame(jo.toJSONString())); | |||
INSTANCE.logService.record(jo,msg,enterpriseName,"大模型"); | |||
sendMsg(devId, msg); | |||
return; | |||
} | |||
} | |||
//未触发大模型的结束case | |||
JSONObject jo = new JSONObject(); | |||
@@ -0,0 +1,136 @@ | |||
package com.xueyi.nlt.nlt.service.impl; | |||
import com.alibaba.fastjson2.JSON; | |||
import com.alibaba.fastjson2.JSONArray; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.baomidou.mybatisplus.core.toolkit.StringUtils; | |||
import com.xueyi.nlt.api.nlt.feign.RemoteAirportLagiLlmService; | |||
import com.xueyi.nlt.netty.server.config.ServerConfig; | |||
import com.xueyi.nlt.nlt.domain.LlmContent; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
import com.xueyi.nlt.nlt.domain.LlmResponse; | |||
import com.xueyi.nlt.nlt.service.ISysLlmService; | |||
import feign.Response; | |||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.stereotype.Service; | |||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | |||
import java.io.IOException; | |||
import java.io.InputStream; | |||
import java.util.List; | |||
import java.util.concurrent.TimeUnit; | |||
import java.util.stream.Collectors; | |||
@Service | |||
public class LagiServiceImpl implements ISysLlmService { | |||
@Autowired | |||
RemoteAirportLagiLlmService remoteAirportLagiLlmService; | |||
@Override | |||
public LlmResponse chat(LlmContext context, LlmParam param) { | |||
return new LlmResponse(); | |||
} | |||
@Override | |||
public LlmResponse stream(LlmContext context, LlmParam param) { | |||
List<String> contentArr = context.getContentList().stream().map(LlmContent::getContent).collect(Collectors.toList()); | |||
SseEmitter emitter = new SseEmitter(); | |||
emitter.onCompletion(new Runnable() { | |||
@Override | |||
public void run() { | |||
System.out.println("进入了onCompletion"); | |||
} | |||
}); | |||
emitter.onError((e) -> { | |||
// e.printStackTrace(); | |||
System.out.println("进入了onError"); | |||
}); | |||
JSONObject testJ = new JSONObject(); | |||
testJ.put("category","airport"); | |||
testJ.put("stream",true); | |||
JSONArray ja1 = new JSONArray(); | |||
JSONObject jo = new JSONObject(); | |||
jo.put("role","user"); | |||
jo.put("content",context.getContentList().get(context.getContentList().size() - 1).getContent()); | |||
ja1.add(jo); | |||
testJ.put("messages",ja1); | |||
ServerConfig.sessionMap.get(context.getDevId()).writeAndFlush(new TextWebSocketFrame("开始了")); | |||
new Thread(()->{ | |||
Response response = remoteAirportLagiLlmService.queryStream(testJ); | |||
Response.Body body = response.body(); | |||
InputStream fileInputStream = null; | |||
try { | |||
fileInputStream = body.asInputStream(); | |||
byte[] bytes = new byte[1024]; | |||
int len = 0; | |||
String buffer = ""; | |||
JSONObject result = new JSONObject(); | |||
StringBuilder sb = new StringBuilder(); | |||
int count = 1; | |||
while ((len = fileInputStream.read(bytes)) != -1) { | |||
sb.append(new String(bytes, 0, len, "utf-8")); | |||
String text = getSentences(sb); | |||
if (StringUtils.isBlank(text)) { | |||
continue; | |||
} | |||
if (text.contains("[DONE]")) { | |||
// 通知前端结束 | |||
result.put("action","chat"); | |||
result.put("motion","idle"); | |||
result.put("traceId",""); | |||
result.put("status",2); | |||
result.put("tts",text.substring(0,text.indexOf("[DONE]"))); | |||
ServerConfig.sessionMap.get(context.getDevId()).writeAndFlush(new TextWebSocketFrame(result.toJSONString())); | |||
break; | |||
} | |||
// 通知前端 | |||
result.put("action","chat"); | |||
result.put("motion","idle"); | |||
result.put("traceId",""); | |||
result.put("status",1); | |||
result.put("tts", text); | |||
ServerConfig.sessionMap.get(context.getDevId()).writeAndFlush(new TextWebSocketFrame(result.toJSONString())); | |||
System.out.println("第" + count++ + "次:" + text); | |||
} | |||
fileInputStream.close(); | |||
} catch (IOException e) { | |||
e.printStackTrace(); | |||
} | |||
}).start(); | |||
ServerConfig.currentTraceMap.put(context.getDevId(),context.getTraceId()); | |||
LlmResponse response = new LlmResponse(); | |||
return response; | |||
} | |||
private String getSentences(StringBuilder sb) { | |||
StringBuilder result = new StringBuilder(); | |||
System.out.println(sb.toString()); | |||
while (sb.indexOf("data:") != -1 && sb.indexOf("\n\n") != -1 ) { | |||
int start = sb.indexOf("data:"); | |||
int end = sb.indexOf("\n\n"); | |||
if (start + 5 <= end) { | |||
if (JSON.isValid(sb.substring(start + 5, end))) { | |||
JSONObject jo = JSON.parseObject(sb.substring(start + 5, end)); | |||
result.append(jo.get("text")); | |||
} else { | |||
result.append(sb.substring(start + 5, end)); | |||
} | |||
sb.delete(start, end+2); | |||
} | |||
} | |||
return result.toString(); | |||
} | |||
} |
@@ -3,6 +3,7 @@ package com.xueyi.nlt.nlt.template; | |||
import com.alibaba.druid.util.StringUtils; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.xueyi.common.core.context.SecurityContextHolder; | |||
import com.xueyi.common.core.utils.core.SpringUtils; | |||
import com.xueyi.common.core.web.result.AjaxResult; | |||
import com.xueyi.nlt.api.nlt.domain.vo.KnowledgeVo; | |||
import com.xueyi.nlt.api.nlt.feign.RemoteQAService; | |||
@@ -15,6 +16,7 @@ import com.xueyi.nlt.nlt.domain.dto.DmPromptDto; | |||
import com.xueyi.nlt.nlt.service.IDmHotspotService; | |||
import com.xueyi.nlt.nlt.service.IDmPromptService; | |||
import com.xueyi.nlt.nlt.service.ISysLlmService; | |||
import com.xueyi.nlt.nlt.service.impl.LagiServiceImpl; | |||
import com.yomahub.tlog.core.annotation.TLogAspect; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
@@ -155,6 +157,33 @@ public class FreeChatTemplate implements BaseTemplate{ | |||
return null; | |||
} | |||
public JSONObject handleStream(String dev, String content, boolean stream) { | |||
Long operatorId = TerminalSecurityContextHolder.getOperatorId(); | |||
String redisKey = "group:nlp:" + SecurityContextHolder.getLocalMap().get("enterprise_id") + ":" + operatorId; | |||
// 根据content内容调用模版并返回结果 | |||
List<String> context = new ArrayList<>(); | |||
context.add("user"); | |||
context.add(content); | |||
//使用stream去除context列表中所有字符串中的引号 | |||
context = context.stream().map(s -> s.replaceAll("\"", "")).collect(java.util.stream.Collectors.toList()); | |||
//webSocketClient.sendMsg(context); | |||
LlmContext llmContext = LlmContext.parse(context,true); | |||
llmContext.setDevId(dev); | |||
LlmParam param = new LlmParam(); | |||
ISysLlmService llmService = SpringUtils.getBean(LagiServiceImpl.class); | |||
LlmResponse response = llmService.stream(llmContext,param); | |||
log.info("llmContext:{}",llmContext); | |||
JSONObject resultJson = new JSONObject(); | |||
resultJson.put("tts","让我想一想。"); | |||
resultJson.put("motion","idle"); | |||
resultJson.put("status","0"); | |||
resultJson.put("action","chat"); | |||
return resultJson; | |||
} | |||
private String generatePrompts(String content){ | |||
String msg = content; | |||
if(msg.equals("")){ | |||