Browse Source

修改:

1、修改webscoket加锁逻辑
tags/B.2.6.4_20240106_base
10710 1 year ago
parent
commit
99f25eb571
3 changed files with 162 additions and 143 deletions
  1. +139
    -131
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/WebSocketClient.java
  2. +17
    -9
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/controller/DmIntentController.java
  3. +6
    -3
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/service/impl/SparkServiceImpl.java

+ 139
- 131
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/WebSocketClient.java View File

@@ -135,10 +135,10 @@ public class WebSocketClient extends WebSocketListener {
e.printStackTrace();
}
}
public void sendMsg(List<String> messages){
this.sendMsg(messages,false, null);
public WebSocketClient sendMsg(List<String> messages){
return this.sendMsg(messages,false, "null");
}
public void sendMsg(List<String> messages,boolean stream,String userId){
public WebSocketClient sendMsg(List<String> messages,boolean stream,String userId){
this.stream = stream;
this.curUserId = userId;
if (messages.size() / 2 > 0) {
@@ -164,9 +164,10 @@ public class WebSocketClient extends WebSocketListener {
ServerConfig.currentTraceMap.put(curUserId,wsc.traceId);
System.out.println("wocket客户端:" + wsc.hashCode());
webSocket = okHttpClient.newWebSocket(request,wsc);
return wsc;
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
//鉴权url
@@ -283,148 +284,155 @@ public class WebSocketClient extends WebSocketListener {
return;
}
}
ResponseData responseData = json.fromJson(text,ResponseData.class);
try {
// System.out.println("code:\n" + responseData.getHeader().get("code"));
if (0 == responseData.getHeader().get("code").getAsInt()) {
System.out.println("###########");
System.out.println("getStatus: " + responseData.getHeader().get("status").getAsInt());
if (2 != responseData.getHeader().get("status").getAsInt()) {
System.out.println("****************");
Payload pl = json.fromJson(responseData.getPayload(), Payload.class);
JsonArray temp = (JsonArray) pl.getChoices().get("text");
JsonObject jo = (JsonObject) temp.get(0);
answer += jo.get("content").getAsString();
answerBuf += jo.get("content").getAsString();
synchronized (this) {
ResponseData responseData = json.fromJson(text,ResponseData.class);
try {
// System.out.println("code:\n" + responseData.getHeader().get("code"));
if (0 == responseData.getHeader().get("code").getAsInt()) {
System.out.println("###########");
System.out.println("getStatus: " + responseData.getHeader().get("status").getAsInt());
if (2 != responseData.getHeader().get("status").getAsInt()) {
System.out.println("****************");
Payload pl = json.fromJson(responseData.getPayload(), Payload.class);
JsonArray temp = (JsonArray) pl.getChoices().get("text");
JsonObject jo = (JsonObject) temp.get(0);
answer += jo.get("content").getAsString();
answerBuf += jo.get("content").getAsString();
// System.out.println(answer);
} else {
Payload pl1 = json.fromJson(responseData.getPayload(), Payload.class);
JsonObject jsonObject = (JsonObject) pl1.getUsage().get("text");
int prompt_tokens = jsonObject.get("prompt_tokens").getAsInt();
JsonArray temp1 = (JsonArray) pl1.getChoices().get("text");
JsonObject jo = (JsonObject) temp1.get(0);
answer += jo.get("content").getAsString();
answerBuf += jo.get("content").getAsString();
System.out.println("返回结果为:\n" + answer);

if (INSTANCE.redisTemplate.hasKey("gpt:websocket:1")) {
DmWebSocketMessageVo message = (DmWebSocketMessageVo) INSTANCE.redisTemplate.opsForValue().get("gpt:websocket:1");
if (message != null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) {
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", answer);
birthdayJo.put("timestamp", message.getFormat().get("timestamp"));
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "birthday", birthdayJo.toString());
INSTANCE.redisTemplate.delete("gpt:websocket:1");
return;
}
if (message!= null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) {
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", answer);
birthdayJo.put("timestamp", message.getFormat().get("timestamp"));
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "hireDate", birthdayJo.toString());
} else {
Payload pl1 = json.fromJson(responseData.getPayload(), Payload.class);
JsonObject jsonObject = (JsonObject) pl1.getUsage().get("text");
int prompt_tokens = jsonObject.get("prompt_tokens").getAsInt();
JsonArray temp1 = (JsonArray) pl1.getChoices().get("text");
JsonObject jo = (JsonObject) temp1.get(0);
answer += jo.get("content").getAsString();
answerBuf += jo.get("content").getAsString();
System.out.println("返回结果为:\n" + answer);

if (INSTANCE.redisTemplate.hasKey("gpt:websocket:1")) {
DmWebSocketMessageVo message = (DmWebSocketMessageVo) INSTANCE.redisTemplate.opsForValue().get("gpt:websocket:1");
if (message != null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) {
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", answer);
birthdayJo.put("timestamp", message.getFormat().get("timestamp"));
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "birthday", birthdayJo.toString());
INSTANCE.redisTemplate.delete("gpt:websocket:1");
this.notifyAll();
return;
}
if (message!= null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) {
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", answer);
birthdayJo.put("timestamp", message.getFormat().get("timestamp"));
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "hireDate", birthdayJo.toString());
INSTANCE.redisTemplate.delete("gpt:websocket:1");
this.notifyAll();
return;
}
if (message != null) {
JSONObject preWebsocketJo = message.getFormat();
JSONObject meetingJo = new JSONObject();
meetingJo.put("timestamp",preWebsocketJo.get("timestamp"));
meetingJo.put("content",answer);
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + preWebsocketJo.getString("orderId"), "meeting", meetingJo.toString());
this.notifyAll();
return;
}
INSTANCE.redisTemplate.delete("gpt:websocket:1");
return;
// 清除systemRole
systemRole = "";

}else {
// 添加上下文
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", questions.get(questions.size() - 1));
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", answer);
// 添加缓存
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", answer);
}
if (message != null) {
JSONObject preWebsocketJo = message.getFormat();
JSONObject meetingJo = new JSONObject();
meetingJo.put("timestamp",preWebsocketJo.get("timestamp"));
meetingJo.put("content",answer);
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + preWebsocketJo.getString("orderId"), "meeting", meetingJo.toString());
this.notifyAll();
// webSocket.close(3,"客户端主动断开链接");
//webSocket.close(1000,"客户端主动断开链接");
}
if (this.stream && !StringUtils.isEmpty(curUserId)) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
logger.info("当前ch:{}",ch.id().asLongText());
if (ch != null) {
List<String> ttsList = new ArrayList<>();
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
//去除转义符
answerBuf = answerBuf.replaceAll("[\\r\\n]", "");
//去除引号
answerBuf = answerBuf.replaceAll("\"", "");
// 处理answer,如果包含"。",则将"。"之前的内容发送给前端
while(answerBuf.contains("。") || answerBuf.contains("?") || answerBuf.contains("!") ||
answerBuf.contains("?") || answerBuf.contains("!")) {
String[] temp = answerBuf.split("。|?|!|\\?|\\!");
ttsList.add(temp[0] + answerBuf.charAt(temp[0].length()));
answerBuf = answerBuf.substring(temp[0].length() + 1);
}
if (2 == responseData.getHeader().get("status").getAsInt() && CollectionUtils.isEmpty(ttsList)) {
jo.put("tts",answerBuf);
jo.put("status",2);
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString()));
} else {
for (int i = 0;i <ttsList.size();i++) {
if (2 == responseData.getHeader().get("status").getAsInt() && i == ttsList.size() - 1) {
jo.put("status",2);
} else {
jo.put("status",1);
}
jo.put("tts",ttsList.get(i));
String str = jo.toJSONString();
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
}
}
INSTANCE.redisTemplate.delete("gpt:websocket:1");
// 清除systemRole
systemRole = "";

}else {
// 添加上下文
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", questions.get(questions.size() - 1));
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", answer);
// 添加缓存
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", answer);
}

// webSocket.close(3,"客户端主动断开链接");
//webSocket.close(1000,"客户端主动断开链接");
}
if (this.stream && !StringUtils.isEmpty(curUserId)) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
logger.info("当前ch:{}",ch.id().asLongText());
if (ch != null) {
List<String> ttsList = new ArrayList<>();
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
//去除转义符
answerBuf = answerBuf.replaceAll("[\\r\\n]", "");
//去除引号
answerBuf = answerBuf.replaceAll("\"", "");
// 处理answer,如果包含"。",则将"。"之前的内容发送给前端
while(answerBuf.contains("。") || answerBuf.contains("?") || answerBuf.contains("!") ||
answerBuf.contains("?") || answerBuf.contains("!")) {
String[] temp = answerBuf.split("。|?|!|\\?|\\!");
ttsList.add(temp[0] + answerBuf.charAt(temp[0].length()));
answerBuf = answerBuf.substring(temp[0].length() + 1);
}
if (2 == responseData.getHeader().get("status").getAsInt() && CollectionUtils.isEmpty(ttsList)) {
jo.put("tts",answerBuf);
} else {
// 添加缓存
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", "-1");
// 判断流式则返回结束状态
if (stream == true) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
if (ch != null) {
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
jo.put("status",2);
jo.put("tts","抱歉,您的问题我无法解答。");
String str = jo.toJSONString();
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString()));
} else {
for (int i = 0;i <ttsList.size();i++) {
if (2 == responseData.getHeader().get("status").getAsInt() && i == ttsList.size() - 1) {
jo.put("status",2);
} else {
jo.put("status",1);
}
jo.put("tts",ttsList.get(i));
String str = jo.toJSONString();
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
ch.writeAndFlush(new TextWebSocketFrame(str));
}
}
System.out.println("返回结果错误:\n" + responseData.getHeader().get("code") + responseData.getHeader().get("message"));
this.notifyAll();
}

} else {
// 添加缓存
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", "-1");
// 判断流式则返回结束状态
if (stream == true) {
} catch (Exception e) {
if (StringUtils.isNotEmpty(curUserId)) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
if (ch != null) {
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
jo.put("status",2);
jo.put("tts","抱歉,您的问题我无法解答。");
String str = jo.toJSONString();
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
jo.put("status",2);
jo.put("tts","大模型出现异常,请稍后重试。");
String str = jo.toJSONString();
logger.info("发生异常client:{},内容:{}",curUserId,jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
System.out.println("返回结果错误:\n" + responseData.getHeader().get("code") + responseData.getHeader().get("message"));
LOCK.notifyAll();
}
} catch (Exception e) {
if (StringUtils.isNotEmpty(curUserId)) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
jo.put("status",2);
jo.put("tts","大模型出现异常,请稍后重试。");
String str = jo.toJSONString();
logger.info("发生异常client:{},内容:{}",curUserId,jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
e.printStackTrace();
this.notifyAll();
}
e.printStackTrace();
}

}




+ 17
- 9
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/controller/DmIntentController.java View File

@@ -11,6 +11,7 @@ import com.xueyi.common.core.constant.digitalman.SkillConstants.SkillType;
import com.xueyi.common.core.context.SecurityContextHolder;
import com.xueyi.common.core.utils.DateUtil;
import com.xueyi.common.core.utils.core.IdUtil;
import com.xueyi.common.core.utils.core.SpringUtils;
import com.xueyi.common.core.web.result.AjaxResult;
import com.xueyi.common.core.web.result.R;
import com.xueyi.common.core.web.validate.V_A;
@@ -34,6 +35,9 @@ import com.xueyi.nlt.api.nlt.feign.RemoteLandingLlmService;
import com.xueyi.nlt.api.nlt.feign.RemoteQAService;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
import com.xueyi.nlt.nlt.domain.LlmResponse;
import com.xueyi.nlt.nlt.domain.dto.DmIntentDto;
import com.xueyi.nlt.nlt.domain.po.DmRegularPo;
import com.xueyi.nlt.nlt.domain.query.DmIntentQuery;
@@ -41,6 +45,7 @@ import com.xueyi.nlt.nlt.domain.vo.IntentTemplateVo;
import com.xueyi.nlt.nlt.domain.vo.MarkRecordVo;
import com.xueyi.nlt.nlt.mapper.DmRegularMapper;
import com.xueyi.nlt.nlt.service.IDmIntentService;
import com.xueyi.nlt.nlt.service.ISysLlmService;
import com.xueyi.nlt.nlt.template.*;
import com.xueyi.system.api.digitalmans.domain.dto.DmManDeviceDto;
import com.xueyi.system.api.digitalmans.domain.dto.DmSkillDto;
@@ -160,6 +165,9 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt

@Autowired
private FlightMessageTemplate flightMessageTemplate;

@Autowired
private ISysLlmService sysLlmService;
/**
* 意图请求
列表
@@ -577,7 +585,6 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt
@PostMapping("/inner/sendMessage")
@ResponseBody
public R<Object> sendMessage(@RequestBody DmWebSocketMessageVo message) {

log.info("websocket sendMessage:{}", message);
if (message == null || message.getFormat() == null) {
return R.fail("参数为空");
@@ -589,27 +596,28 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt
String meetingRoom = jo.getString("meetingRoom");
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) {
String prefix = "假设你是一名公司前台,你看到" + message.getFormat().get("name")+ "已知今天是他的生日。请你从个人角度输出给他的生日贺词。要求待人平和,具有人情味,用词正式,内容与工作无关。输出只包含你要对他说的话,在20字以内。";
webSocketClient.sendMsg(prefix);
redisTemplate2.opsForValue().set("gpt:websocket" + ":" + "1", message);
LlmContext context = new LlmContext(prefix);
LlmResponse response = sysLlmService.chat(context, new LlmParam());
redisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "birthday", response.getContent());
return R.ok();
}
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) {
String prefix = "假设你是一名公司前台,你看到"+ message.getFormat().get("name")+ ",已知今天是他入职" + message.getFormat().get("years")+"周年,请你从个人角度说出对他入职周年的祝贺。要求具有人情味,有特色,字数在25字左右,不要提到生日,要带人名。输出只包含你要对他说的话。";
webSocketClient.sendMsg(prefix);
redisTemplate2.opsForValue().set("gpt:websocket" + ":" + "1", message);
LlmContext context = new LlmContext(prefix);
LlmResponse response = sysLlmService.chat(context, new LlmParam());
redisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "hireDate", response.getContent());
return R.ok();
}
Date date = new Date(timestamp.longValue());
if (message.getSkillCode().equals("1")) {
String prefix = "假设你是一名公司前台,你看到在你们公司工作的\\"+ jo.getString("orderName")+ "\\,请你从个人角度提醒他参加\\" +
dateFormat4.format(timestamp) + "\\在\\" + meetingRoom + "\\的会,要求语气友好。输出只包含你要对他说的话,在20字左右。";
webSocketClient.sendMsg(prefix);
// 设置缓存
redisTemplate2.opsForValue().set("gpt:websocket" + ":" + "1", message);
LlmContext context = new LlmContext(prefix);
LlmResponse response = sysLlmService.chat(context, new LlmParam());
redisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "meeting", response.getContent());
}
return R.ok();
}

/**
* 意图请求
列表


+ 6
- 3
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/service/impl/SparkServiceImpl.java View File

@@ -1,5 +1,6 @@
package com.xueyi.nlt.nlt.service.impl;

import com.xueyi.common.core.utils.core.SpringUtils;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.nlt.domain.LlmContent;
import com.xueyi.nlt.nlt.domain.LlmContext;
@@ -27,16 +28,18 @@ public class SparkServiceImpl implements ISysLlmService {
@Override
public LlmResponse chat(LlmContext context, LlmParam param) {
List<String> contentArr = context.getContentList().stream().map(LlmContent::getContent).collect(Collectors.toList());
synchronized (WebSocketClient.LOCK) {
webSocketClient.sendMsg(contentArr);
WebSocketClient socketClient = SpringUtils.getBean(WebSocketClient.class);
webSocketClient = socketClient.sendMsg(contentArr);
synchronized (webSocketClient) {
try {
WebSocketClient.LOCK.wait();
webSocketClient.wait();

} catch (InterruptedException e) {
e.printStackTrace();
Thread.currentThread().interrupt();
}
String result = redisTemplate.opsForValue().get("group:websocket:content");
result = webSocketClient.answer;
LlmResponse response = new LlmResponse();
response.setContent(result);
return response;


Loading…
Cancel
Save