Преглед на файлове

yinruoxi

修改:
    1.重构websocketClient
dev-rebuild-websocket
kira преди 1 година
родител
ревизия
4f6a5f1524
променени са 13 файла, в които са добавени 606 реда и са изтрити 688 реда
  1. +0
    -521
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/WebSocketClient.java
  2. +131
    -0
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/WebSocketClientManager.java
  3. +325
    -0
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/listener/LlmWebSocketListener.java
  4. +2
    -29
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/controller/DmWebsocketController.java
  5. +0
    -2
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/server/config/ServerConfig.java
  6. +5
    -8
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/controller/DmIntentController.java
  7. +10
    -1
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/domain/LlmContext.java
  8. +15
    -2
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/domain/LlmParam.java
  9. +25
    -11
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/service/impl/SparkServiceImpl.java
  10. +64
    -70
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/FreeChatTemplate.java
  11. +25
    -28
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/GenerativeKnowledgeTemplate.java
  12. +2
    -14
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/MeetingOrderTemplate.java
  13. +2
    -2
      xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/MovieChatTemplate.java

+ 0
- 521
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/WebSocketClient.java Целия файл

@@ -1,521 +0,0 @@
package com.xueyi.nlt.netty.client;

import cn.hutool.core.lang.Snowflake;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.nacos.shaded.com.google.gson.Gson;
import com.alibaba.nacos.shaded.com.google.gson.JsonArray;
import com.alibaba.nacos.shaded.com.google.gson.JsonObject;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.xueyi.common.core.utils.core.IdUtil;
import com.xueyi.common.web.interceptor.ApiRequestInterceptor;
import com.xueyi.nlt.api.netty.domain.vo.DmWebSocketMessageVo;
import com.xueyi.nlt.netty.server.config.ServerConfig;
import com.xueyi.nlt.nlt.domain.vo.LlmQueryVo;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import okhttp3.*;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.io.Serializable;
import java.net.URL;
import java.nio.charset.Charset;
import java.text.SimpleDateFormat;
import java.util.*;

@Component
public class WebSocketClient extends WebSocketListener {

private static final Logger logger = LoggerFactory.getLogger(WebSocketClient.class);
public static WebSocketClient INSTANCE;
@Autowired
private RedisTemplate redisTemplate;
@Autowired
private StringRedisTemplate stringRedisTemplate;


@Value("${secret.spark.appId}")
private String appId;

@Value("${secret.spark.apiSecret}")
private String apiSecret;

@Value("${secret.spark.apiKey}")
private String apiKey;
@Value("${secret.spark.hostUrl}")
public String hostUrl;
public final static Object LOCK = new Object();

// public static String APPID = "3d9282da";//从开放平台控制台中获取
// public static String APIKEY = "7c217b3a313f4b66fcc14a8e97f85103";//从开放平台控制台中获取
// public static String APISecret = "ZTRiNDQwMTRlOTlmZDQwMDUwYTdjMDM0";//从开放平台控制台中获取

public WebSocket webSocket;
public static final Gson json = new Gson();
// public static String question = "假设你是一位前台,你需要通过与其他人对话来获取会议相关信息,已知今天是2023-7-19,你需要获取会议日期,开始时间,持续时间,会议地点,会议主题。时间用类似00:00的格式输出。对方的话中可能不包含全部信息,对于未知的信息填充为none。如果所有信息都已知那么commit为true。否则为false。将你获得的信息输出为json格式。对方的话是:“明天下午开个会。从两点开到下午三点,在大会议室开,主题是访客接待”,只输出最后的json。输出只有一行,输出格式为{date:,start_time:,duration:,location:,theme:commit:}。";//可以修改question 内容,来向模型提问
// 定义内存共享变量traceId
public Long traceId;
public String question = "请帮我安排五一出行计划";//可以修改question 内容,来向模型提问
public String systemRole = "";
public List<String> questions = new ArrayList<>();//可以修改question 内容,来向模型提问

public boolean stream = false;
public String curUserId = null;
public String answer = "";
public String answerBuf = "";


@PostConstruct
public void init() {
INSTANCE = this;
INSTANCE.redisTemplate = this.redisTemplate;
INSTANCE.stringRedisTemplate = this.stringRedisTemplate;
INSTANCE.appId = this.appId;
}

public static void main(String[] args) {
synchronized (LOCK) {
try {
//构建鉴权httpurl
String authUrl = getAuthorizationUrl("https://spark-api.xf-yun.com/v3.1/chat", "54f6e81f40a31d66d976496de895a7a4", "ZDYyMjNmMTlkYTE0YWRmOWUwZTYxNjYz");
OkHttpClient okHttpClient = new OkHttpClient.Builder().build();
String url = authUrl.replace("https://","wss://").replace("http://","ws://");
Request request = new Request.Builder().url(url).build();
WebSocket webSocket = okHttpClient.newWebSocket(request,new WebSocketClient());
LOCK.wait();
System.out.println("查询完成");
} catch (InterruptedException ie) {
ie.printStackTrace();
Thread.currentThread().interrupt();
} catch (Exception e) {
e.printStackTrace();
}
}



// write your code here

}

/**
* 调用讯飞开放平台接口发送消息,并将消息存入redis队列
* @param message
* @param key
*/
public void sendMsg(String message,String key){
LlmQueryVo vo = new LlmQueryVo();
vo.setQuestion(message);
vo.setTemplate(key);
redisTemplate.opsForList().rightPush("group:websocket:quary",vo);
sendMsg(message);
}
public void sendMsg(String message){
question = message;

try {
//构建鉴权httpurl
String authUrl = getAuthorizationUrl(hostUrl,apiKey,apiSecret);
OkHttpClient okHttpClient = new OkHttpClient.Builder().build();
String url = authUrl.replace("https://","wss://").replace("http://","ws://");
Request request = new Request.Builder().url(url).build();
webSocket = okHttpClient.newWebSocket(request,new WebSocketClient());

} catch (Exception e) {
e.printStackTrace();
}
}
public WebSocketClient sendMsg(List<String> messages){
return this.sendMsg(messages,false, "null");
}
public WebSocketClient sendMsg(List<String> messages,boolean stream,String userId){
this.stream = stream;
this.curUserId = userId;
if (messages.size() / 2 > 0) {
systemRole = messages.get(0);
messages.remove(0);
}
questions = messages;
question = null;

try {
//构建鉴权httpurl
String authUrl = getAuthorizationUrl(hostUrl,apiKey,apiSecret);
OkHttpClient okHttpClient = new OkHttpClient.Builder().build();
String url = authUrl.replace("https://","wss://").replace("http://","ws://");
Request request = new Request.Builder().url(url).build();
WebSocketClient wsc = new WebSocketClient();
wsc.stream = stream;
wsc.curUserId = userId;
wsc.questions = questions;
wsc.question = question;
wsc.systemRole = systemRole;
wsc.traceId = IdUtil.getSnowflakeNextId();
ServerConfig.currentTraceMap.put(curUserId,wsc.traceId);
System.out.println("wocket客户端:" + wsc.hashCode());
webSocket = okHttpClient.newWebSocket(request,wsc);
return wsc;
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
//鉴权url
public static String getAuthorizationUrl(String hostUrl , String apikey ,String apisecret) throws Exception {
//获取host
URL url = new URL(hostUrl);
//获取鉴权时间 date
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
System.out.println("format:\n" + format );
format.setTimeZone(TimeZone.getTimeZone("GMT"));
String date = format.format(new Date());
//获取signature_origin字段
StringBuilder builder = new StringBuilder("host: ").append(url.getHost()).append("\n").
append("date: ").append(date).append("\n").
append("GET ").append(url.getPath()).append(" HTTP/1.1");
System.out.println("signature_origin:\n" + builder);
//获得signatue
Charset charset = Charset.forName("UTF-8");
Mac mac = Mac.getInstance("hmacsha256");
SecretKeySpec sp = new SecretKeySpec(apisecret.getBytes(charset),"hmacsha256");
mac.init(sp);
byte[] basebefore = mac.doFinal(builder.toString().getBytes(charset));
String signature = Base64.getEncoder().encodeToString(basebefore);
//获得 authorization_origin
String authorization_origin = String.format("api_key=\"%s\",algorithm=\"%s\",headers=\"%s\",signature=\"%s\"",apikey,"hmac-sha256","host date request-line",signature);
//获得authorization
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes(charset));
//获取httpurl
HttpUrl httpUrl = HttpUrl.parse("https://" + url.getHost() + url.getPath()).newBuilder().//
addQueryParameter("authorization", authorization).//
addQueryParameter("date", date).//
addQueryParameter("host", url.getHost()).//
build();

return httpUrl.toString();
}

//重写onopen
@Override
public void onOpen(WebSocket webSocket, Response response) {
super.onOpen(webSocket, response);
new Thread(()->{
JsonObject frame = new JsonObject();
JsonObject header = new JsonObject();
JsonObject chat = new JsonObject();
JsonObject parameter = new JsonObject();
JsonObject payload = new JsonObject();
JsonObject message = new JsonObject();
JsonObject text = new JsonObject();
JsonArray ja = new JsonArray();

//填充header
header.addProperty("app_id",INSTANCE.appId);
header.addProperty("uid","123456789");
//填充parameter
// chat.addProperty("domain","general"); //1.0版本
chat.addProperty("domain","generalv3"); // 3.0版本
chat.addProperty("random_threshold",0.5);
chat.addProperty("max_tokens",1024);
chat.addProperty("auditing","default");
parameter.add("chat",chat);
if (!StringUtils.isEmpty(systemRole)) {
text = new JsonObject();
//填充payload
text.addProperty("role","system");
text.addProperty("content",systemRole);
ja.add(text);
}
if (!StringUtils.isEmpty(question)) {
text = new JsonObject();
//填充payload
text.addProperty("role","user");
text.addProperty("content",question);
ja.add(text);
}else {
for (int i = 0;i < questions.size();i++) {

text = new JsonObject();
if (i % 2 == 0) {
text.addProperty("role","user");

} else {
text.addProperty("role","assistant");
}
text.addProperty("content",questions.get(i));
System.out.println(text.toString());
ja.add(text);
}
}
// message.addProperty("text",ja.getAsString());
message.add("text",ja);
payload.add("message",message);
frame.add("header",header);
frame.add("parameter",parameter);
frame.add("payload",payload);
System.out.println("frame:\n" + frame.toString());
webSocket.send(frame.toString());


}


).start();
}

//重写onmessage

@Override
public void onMessage(WebSocket webSocket, String text) {
super.onMessage(webSocket, text);
System.out.println("text:\n" + text);
if (!StringUtils.isEmpty(curUserId)) {
if (ServerConfig.currentTraceMap.containsKey(curUserId) && !ServerConfig.currentTraceMap.get(curUserId).equals(traceId)) {
return;
}
}
synchronized (this) {
ResponseData responseData = json.fromJson(text,ResponseData.class);
try {
// System.out.println("code:\n" + responseData.getHeader().get("code"));
if (0 == responseData.getHeader().get("code").getAsInt()) {
System.out.println("###########");
System.out.println("getStatus: " + responseData.getHeader().get("status").getAsInt());
if (2 != responseData.getHeader().get("status").getAsInt()) {
System.out.println("****************");
Payload pl = json.fromJson(responseData.getPayload(), Payload.class);
JsonArray temp = (JsonArray) pl.getChoices().get("text");
JsonObject jo = (JsonObject) temp.get(0);
answer += jo.get("content").getAsString();
answerBuf += jo.get("content").getAsString();
// System.out.println(answer);
} else {
Payload pl1 = json.fromJson(responseData.getPayload(), Payload.class);
JsonObject jsonObject = (JsonObject) pl1.getUsage().get("text");
int prompt_tokens = jsonObject.get("prompt_tokens").getAsInt();
JsonArray temp1 = (JsonArray) pl1.getChoices().get("text");
JsonObject jo = (JsonObject) temp1.get(0);
answer += jo.get("content").getAsString();
answerBuf += jo.get("content").getAsString();
System.out.println("返回结果为:\n" + answer);

if (INSTANCE.redisTemplate.hasKey("gpt:websocket:1")) {
DmWebSocketMessageVo message = (DmWebSocketMessageVo) INSTANCE.redisTemplate.opsForValue().get("gpt:websocket:1");
if (message != null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) {
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", answer);
birthdayJo.put("timestamp", message.getFormat().get("timestamp"));
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "birthday", birthdayJo.toString());
INSTANCE.redisTemplate.delete("gpt:websocket:1");
this.notifyAll();
return;
}
if (message!= null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) {
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", answer);
birthdayJo.put("timestamp", message.getFormat().get("timestamp"));
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "hireDate", birthdayJo.toString());
INSTANCE.redisTemplate.delete("gpt:websocket:1");
this.notifyAll();
return;
}
if (message != null) {
JSONObject preWebsocketJo = message.getFormat();
JSONObject meetingJo = new JSONObject();
meetingJo.put("timestamp",preWebsocketJo.get("timestamp"));
meetingJo.put("content",answer);
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + preWebsocketJo.getString("orderId"), "meeting", meetingJo.toString());
this.notifyAll();
return;
}
INSTANCE.redisTemplate.delete("gpt:websocket:1");
// 清除systemRole
systemRole = "";

}else {
// 添加上下文
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", questions.get(questions.size() - 1));
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", answer);
// 添加缓存
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", answer);
}
this.notifyAll();
// webSocket.close(3,"客户端主动断开链接");
//webSocket.close(1000,"客户端主动断开链接");
}
if (this.stream && !StringUtils.isEmpty(curUserId)) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
logger.info("当前ch:{}",ch.id().asLongText());
if (ch != null) {
List<String> ttsList = new ArrayList<>();
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
//去除转义符
answerBuf = answerBuf.replaceAll("[\\r\\n]", "");
//去除引号
answerBuf = answerBuf.replaceAll("\"", "");
// 处理answer,如果包含"。",则将"。"之前的内容发送给前端
while(answerBuf.contains("。") || answerBuf.contains("?") || answerBuf.contains("!") ||
answerBuf.contains("?") || answerBuf.contains("!")) {
String[] temp = answerBuf.split("。|?|!|\\?|\\!");
ttsList.add(temp[0] + answerBuf.charAt(temp[0].length()));
answerBuf = answerBuf.substring(temp[0].length() + 1);
}
if (2 == responseData.getHeader().get("status").getAsInt() && CollectionUtils.isEmpty(ttsList)) {
jo.put("tts",answerBuf);
jo.put("status",2);
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString()));
} else {
for (int i = 0;i <ttsList.size();i++) {
if (2 == responseData.getHeader().get("status").getAsInt() && i == ttsList.size() - 1) {
jo.put("status",2);
} else {
jo.put("status",1);
}
jo.put("tts",ttsList.get(i));
String str = jo.toJSONString();
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
}
}
}

} else {
// 添加缓存
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", "-1");
// 判断流式则返回结束状态
if (stream == true) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
if (ch != null) {
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
jo.put("status",2);
jo.put("tts","抱歉,您的问题我无法解答。");
String str = jo.toJSONString();
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
}
System.out.println("返回结果错误:\n" + responseData.getHeader().get("code") + responseData.getHeader().get("message"));
this.notifyAll();
}
} catch (Exception e) {
if (StringUtils.isNotEmpty(curUserId)) {
Channel ch = ServerConfig.sessionMap.get(curUserId);
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",traceId);
jo.put("status",2);
jo.put("tts","大模型出现异常,请稍后重试。");
String str = jo.toJSONString();
logger.info("发生异常client:{},内容:{}",curUserId,jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(str));
}
e.printStackTrace();
this.notifyAll();
}
}

}


//重写onFailure

@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
super.onFailure(webSocket, t, response);
System.out.println(response);
}



class ResponseData{
private JsonObject header;
private JsonObject payload;

public JsonObject getHeader() {
return header;
}

public JsonObject getPayload() {
return payload;
}
}

class Header{
private int code ;
private String message;
private String sid;
private String status;

public int getCode() {
return code;
}

public String getMessage() {
return message;
}

public String getSid() {
return sid;
}

public String getStatus() {
return status;
}
}

class Payload{
private JsonObject choices;
private JsonObject usage;

public JsonObject getChoices() {
return choices;
}

public JsonObject getUsage() {
return usage;
}
}

class Choices{
private int status;
private int seq;
private JsonArray text;

public int getStatus() {
return status;
}

public int getSeq() {
return seq;
}

public JsonArray getText() {
return text;
}
}






}

+ 131
- 0
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/WebSocketClientManager.java Целия файл

@@ -0,0 +1,131 @@
package com.xueyi.nlt.netty.client;

import com.xueyi.nlt.netty.client.listener.LlmWebSocketListener;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
import okhttp3.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.net.URL;
import java.nio.charset.Charset;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

@Component
public class WebSocketClientManager {

private static final Logger logger = LoggerFactory.getLogger(WebSocketClientManager.class);

private static final int MAX_CONCURRENT_CONNECTIONS = 5;
private static final int CORE_POOL_SIZE = 2;
private static final int MAX_POOL_SIZE = 4;
private static final long KEEP_ALIVE_TIME = 60L;

private ExecutorService connectionPool;
private OkHttpClient client;

@Value("${secret.spark.apiSecret}")
private String apiSecret;

@Value("${secret.spark.apiKey}")
private String apiKey;
@Value("${secret.spark.hostUrl}")
public String hostUrl;


@PostConstruct
public void init() {
ConnectionPool okHttpConnectionPool = new ConnectionPool(MAX_CONCURRENT_CONNECTIONS, KEEP_ALIVE_TIME, TimeUnit.SECONDS);
client = new OkHttpClient.Builder().connectionPool(okHttpConnectionPool).build();
connectionPool = new ThreadPoolExecutor(CORE_POOL_SIZE, MAX_POOL_SIZE, KEEP_ALIVE_TIME, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
}


public static void main(String[] args) {
LlmContext context = new LlmContext("今天北京天气怎么样");
LlmParam param = new LlmParam();
LlmWebSocketListener listener = new LlmWebSocketListener("12345", param, context,false);
synchronized (listener) {
try {
//构建鉴权httpurl
String authUrl = getAuthorizationUrl("https://spark-api.xf-yun.com/v3.1/chat", "54f6e81f40a31d66d976496de895a7a4", "ZDYyMjNmMTlkYTE0YWRmOWUwZTYxNjYz");
OkHttpClient okHttpClient = new OkHttpClient.Builder().build();
String url = authUrl.replace("https://","wss://").replace("http://","ws://");
Request request = new Request.Builder().url(url).build();
WebSocket webSocket = okHttpClient.newWebSocket(request,listener);
listener.wait();
System.out.println("查询完成");
} catch (InterruptedException ie) {
ie.printStackTrace();
Thread.currentThread().interrupt();
} catch (Exception e) {
e.printStackTrace();
}
}
}

/**
* 调用讯飞开放平台接口发送消息
* @param listener
*/

public void startWebSocketClient(LlmWebSocketListener listener) {
connectionPool.execute(() -> {
try {
String authUrl = getAuthorizationUrl(hostUrl, apiKey, apiSecret);
String url = authUrl.replace("https://", "wss://").replace("http://", "ws://");
Request request = new Request.Builder().url(url).build();
WebSocket webSocket = client.newWebSocket(request, listener);
} catch (Exception e) {
logger.error("startWebSocketClient error", e.getStackTrace());
}
});
}

//鉴权url
public static String getAuthorizationUrl(String hostUrl , String apikey ,String apisecret) throws Exception {
//获取host
URL url = new URL(hostUrl);
//获取鉴权时间 date
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
System.out.println("format:\n" + format );
format.setTimeZone(TimeZone.getTimeZone("GMT"));
String date = format.format(new Date());
//获取signature_origin字段
StringBuilder builder = new StringBuilder("host: ").append(url.getHost()).append("\n").
append("date: ").append(date).append("\n").
append("GET ").append(url.getPath()).append(" HTTP/1.1");
System.out.println("signature_origin:\n" + builder);
//获得signatue
Charset charset = Charset.forName("UTF-8");
Mac mac = Mac.getInstance("hmacsha256");
SecretKeySpec sp = new SecretKeySpec(apisecret.getBytes(charset),"hmacsha256");
mac.init(sp);
byte[] basebefore = mac.doFinal(builder.toString().getBytes(charset));
String signature = Base64.getEncoder().encodeToString(basebefore);
//获得 authorization_origin
String authorization_origin = String.format("api_key=\"%s\",algorithm=\"%s\",headers=\"%s\",signature=\"%s\"",apikey,"hmac-sha256","host date request-line",signature);
//获得authorization
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes(charset));
//获取httpurl
HttpUrl httpUrl = HttpUrl.parse("https://" + url.getHost() + url.getPath()).newBuilder().//
addQueryParameter("authorization", authorization).//
addQueryParameter("date", date).//
addQueryParameter("host", url.getHost()).//
build();

return httpUrl.toString();
}

}

+ 325
- 0
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/client/listener/LlmWebSocketListener.java Целия файл

@@ -0,0 +1,325 @@
package com.xueyi.nlt.netty.client.listener;

import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.nacos.shaded.com.google.gson.Gson;
import com.alibaba.nacos.shaded.com.google.gson.JsonArray;
import com.alibaba.nacos.shaded.com.google.gson.JsonElement;
import com.alibaba.nacos.shaded.com.google.gson.JsonObject;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.xueyi.nlt.netty.server.config.ServerConfig;
import com.xueyi.nlt.nlt.domain.LlmContent;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;

public class LlmWebSocketListener extends WebSocketListener {

private static final Logger logger = LoggerFactory.getLogger(LlmWebSocketListener.class);

private static final Integer SPARK_RESPONSE_STATUS_FIRST_RESULT = 0;
private static final Integer SPARK_RESPONSE_STATUS_INTERMEDIATE_RESULT = 1;
private static final Integer SPARK_RESPONSE_STATUS_LAST_RESULT = 2;
private static final Double RANDOM_THRESHOLD = 0.5;
protected String appId;

protected LlmParam llmParam;

protected LlmContext llmContext;

public String systemRole = "";
// 是否为流式
protected boolean stream = false;

// 是否关闭websocket
private Boolean wsCloseFlag = false;


public String question = "";//可以修改question 内容,来向模型提问

public List<String> questions = new ArrayList<>();//可以修改question 内容,来向模型提问

public String answer = "";
public String answerBuf = "";

public static final Gson json = new Gson();

public LlmWebSocketListener(String appId, LlmParam llmParam, LlmContext context, boolean stream) {
this.appId = appId;
this.llmParam = llmParam;
this.llmContext = context;
this.stream = stream;
}

@Override
public void onOpen(WebSocket webSocket, Response response) {
super.onOpen(webSocket, response);
new Thread(()->{
try {
JsonObject frame = new JsonObject();
JsonObject header = new JsonObject();
JsonObject chat = new JsonObject();
JsonObject parameter = new JsonObject();
JsonObject payload = new JsonObject();
JsonObject message = new JsonObject();
JsonObject text = new JsonObject();
JsonArray ja = new JsonArray();

//填充header
header.addProperty("app_id",appId);
header.addProperty("uid","123456789");
//填充parameter
// chat.addProperty("domain","general"); //1.0版本
chat.addProperty("domain",llmParam.getModel()); // 3.0版本
chat.addProperty("temperature",llmParam.getTemperature());
chat.addProperty("max_tokens",llmParam.getMaxTokens());
parameter.add("chat",chat);

// 插入大模型上下文
for (LlmContent llmContent : llmContext.getContentList()) {
JsonElement jsonTree = json.toJsonTree(llmContent, LlmContent.class);
ja.add(jsonTree);
}
// message.addProperty("text",ja.getAsString());
message.add("text",ja);
payload.add("message",message);
frame.add("header",header);
frame.add("parameter",parameter);
frame.add("payload",payload);
System.out.println("frame:\n" + frame.toString());
webSocket.send(frame.toString());
// 等待服务端返回完毕后关闭
while (true) {
// System.err.println(wsCloseFlag + "---");
Thread.sleep(200);
if (wsCloseFlag) {
break;
}
}
webSocket.close(1000, "");
} catch (InterruptedException e) {
logger.error("websocket发送消息异常:{}",e.getMessage());
}

}
).start();
}

@Override
public void onMessage(WebSocket webSocket, String text) {
super.onMessage(webSocket, text);
System.out.println("text:\n" + text);
if (ServerConfig.currentTraceMap.containsKey(llmContext.getDevId()) && !ServerConfig.currentTraceMap.get(llmContext.getDevId()).equals(llmContext.getTraceId())) {
return;
}
ResponseData responseData = json.fromJson(text, ResponseData.class);

synchronized (this) {
// 如果返回的code不等于0,打印错误日志,设置websocket关闭标志位,设置answer为抱歉,您的问题我无法解答。如果为流式调用则向channel发送错误信息
if (0 != responseData.header.code) {
logger.error("返回结果错误:{}" , responseData);
this.wsCloseFlag = true;
this.answer = "抱歉,您的问题我无法解答。";
if (this.stream) {
Channel ch = ServerConfig.sessionMap.get(llmContext.getDevId());
if (ch != null) {
JSONObject jo = formatToChannel("抱歉,您的问题我无法解答。",SPARK_RESPONSE_STATUS_LAST_RESULT,responseData.header.code);
logger.info("发送到client:{},id:{},内容:{}",llmContext.getDevId(),ch.id().asLongText(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString()));
}
} else {
this.notifyAll();
}
return;
}
try {
System.out.println("###########");
System.out.println("getStatus: " + responseData.header.status);
List<Text> textList = responseData.payload.choices.text;
for (Text temp : textList) {
answer += temp.content;
answerBuf += temp.content;
}
if (stream) {
Channel ch = ServerConfig.sessionMap.get(llmContext.getDevId());
logger.info("当前ch:{}",ch.id().asLongText());
if (ch != null) {
// 向数字人推送流式消息,将answerBuf中的内容做拆分,如果包含"。",则将"。"之前的内容以列表形式发送给前端
List<String> ttsList = new ArrayList<>();
//去除转义符
answerBuf = answerBuf.replaceAll("[\\r\\n]", "");
//去除引号
answerBuf = answerBuf.replaceAll("\"", "");
// 处理answer,如果包含"。",则将"。"之前的内容发送给前端
while(answerBuf.contains("。") || answerBuf.contains("?") || answerBuf.contains("!") ||
answerBuf.contains("?") || answerBuf.contains("!")) {
String[] temp = answerBuf.split("。|?|!|\\?|\\!");
ttsList.add(temp[0] + answerBuf.charAt(temp[0].length()));
answerBuf = answerBuf.substring(temp[0].length() + 1);
}
for (int i = 0;i < ttsList.size();i++) {
JSONObject resJo;
if (i < ttsList.size() - 1) {
resJo = formatToChannel(ttsList.get(i),1,responseData.header.code);
} else {
resJo = formatToChannel(ttsList.get(i),responseData.header.status,responseData.header.code);
}
logger.info("发送到client:{},id:{},内容:{}",llmContext.getDevId(),ch.id().asLongText(),resJo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(resJo.toJSONString()));
}
}
}
// 判断当前状态是否为发送完成,如果为发送完成,则打印answer,设置websocket关闭标志位,如果为非流式调用,释放锁。
if (SPARK_RESPONSE_STATUS_LAST_RESULT == responseData.header.status) {
System.out.println("返回结果为:\n" + answer);
//关闭释放资源
this.wsCloseFlag = true;
// 如果是非流式调用,释放锁
if (!stream) {
this.notifyAll();
}

}
} catch (Exception e) {
logger.error("返回结果错误:{}" , e.getMessage());
this.answer = "大模型出现异常,请稍后重试。";
if (stream) {
Channel ch = ServerConfig.sessionMap.get(llmContext.getDevId());
if (ch != null) {
JSONObject jo = formatToChannel("大模型出现异常,请稍后重试。",SPARK_RESPONSE_STATUS_LAST_RESULT,-1);
logger.info("发生异常client:{},内容:{}",llmContext.getDevId(),jo.toJSONString());
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString()));
}
} else {
this.notifyAll();
}
}
}
}

@Override
public void onMessage(WebSocket webSocket, ByteString bytes) {
super.onMessage(webSocket, bytes);
}

@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
super.onClosing(webSocket, code, reason);
}

@Override
public void onClosed(WebSocket webSocket, int code, String reason) {
super.onClosed(webSocket, code, reason);
}

@Override
public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) {
super.onFailure(webSocket, t, response);
}

// 定义一个局部方法,从SessionMap中获取获取当前用户的channel,向channel发送JSON格式的Response
public JSONObject formatToChannel(String msg, Integer status,Integer code) {
JSONObject jo = new JSONObject();
jo.put("action","chat");
jo.put("motion","idle");
jo.put("traceId",llmContext.getTraceId());
jo.put("status",status);
jo.put("tts",msg);
jo.put("code",code);
return jo;
}


class ResponseData{
private Header header;
private Payload payload;

public Header getHeader() {
return header;
}

public Payload getPayload() {
return payload;
}
}

class Header{
private int code ;
private String message;
private String sid;
private int status;

public int getCode() {
return code;
}

public String getMessage() {
return message;
}

public String getSid() {
return sid;
}

public int getStatus() {
return status;
}
}

class Payload{
private Choices choices;
private Useage usage;

public Choices getChoices() {
return choices;
}

public Useage getUsage() {
return usage;
}
}

class Choices{
int status;
int seq;
List<Text> text;

public int getStatus() {
return status;
}

public int getSeq() {
return seq;
}

public List<Text> getText() {
return text;
}
}
class Text {
String role;
String content;
}
class Useage{
/** 保留字段,可忽略 */
Integer question_tokens;
/** 包含历史问题的总tokens大小 */
Integer prompt_tokens;
/** 回答的tokens大小 */
Integer completion_tokens;
/** prompt_tokens和completion_tokens的和,也是本次交互计费的tokens大小 */
Integer total_tokens;
}
}

+ 2
- 29
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/controller/DmWebsocketController.java Целия файл

@@ -2,16 +2,9 @@ package com.xueyi.nlt.netty.controller;

import com.alibaba.fastjson2.JSONObject;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.xueyi.common.cache.utils.SourceUtil;
import com.xueyi.common.core.constant.basic.SecurityConstants;
import com.xueyi.common.core.web.result.AjaxResult;
import com.xueyi.common.core.web.result.R;
import com.xueyi.nlt.api.netty.domain.vo.DmWebSocketMessageVo;
import com.xueyi.nlt.api.nlt.domain.vo.DmIntentVo;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.nlt.domain.vo.IntentTemplateVo;
import com.xueyi.system.api.digitalmans.domain.dto.DmManDeviceDto;
import com.xueyi.system.api.model.Source;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
@@ -29,7 +22,7 @@ public class DmWebsocketController {
private static final Logger log = LoggerFactory.getLogger(DmWebsocketController.class);

@Autowired
WebSocketClient webSocketClient;
WebSocketClientManager webSocketClientManager;

@Autowired
StringRedisTemplate stringRedisTemplate;
@@ -54,26 +47,6 @@ public class DmWebsocketController {
SimpleDateFormat dateFormat3 = new SimpleDateFormat("MM-dd");
Double timestamp = Double.valueOf((String)jo.get("timestamp"));
String meetingRoom = jo.getString("meetingRoom");
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) {
String prefix = "假设你是一名公司前台,你看到" + message.getFormat().get("name")+ "已知今天是他的生日。请你从个人角度输出给他的生日贺词。要求待人平和,具有人情味,用词正式,内容与工作无关。输出只包含你要对他说的话,在20字以内。";
webSocketClient.sendMsg(prefix);
redisTemplate.opsForValue().set("gpt:websocket" + ":" + "1", message);
return R.ok();
}
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hiredate")) {
String prefix = "假设你是一名公司前台,你看到"+ message.getFormat().get("name")+ ",已知今天是他入职1周年,请你从个人角度说出对他入职周年的祝贺。要求具有人情味,有特色,字数在25字左右,不要提到生日,要带人名。输出只包含你要对他说的话。";
webSocketClient.sendMsg(prefix);
redisTemplate.opsForValue().set("gpt:websocket" + ":" + "1", message);
return R.ok();
}
Date date = new Date(timestamp.longValue());
if (message.getSkillCode().equals("1")) {
String prefix = "假设你是一名公司前台,你看到在你们公司工作的\\"+ jo.getString("orderName")+ "\\,请你从个人角度提醒他参加\\" +
dateFormat3.format(timestamp) + "\\在\\" + meetingRoom + "\\的会,要求语气友好。输出只包含你要对他说的话,在20字左右。";
webSocketClient.sendMsg(prefix);
// 设置缓存
redisTemplate.opsForValue().set("gpt:websocket" + ":" + "1", message);
}
return R.ok();
}
}

+ 0
- 2
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/netty/server/config/ServerConfig.java Целия файл

@@ -1,7 +1,5 @@
package com.xueyi.nlt.netty.server.config;

import com.alibaba.fastjson2.JSONObject;
import com.xueyi.nlt.netty.client.WebSocketClient;
import io.netty.channel.Channel;

import java.util.concurrent.ConcurrentHashMap;


+ 5
- 8
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/controller/DmIntentController.java Целия файл

@@ -11,7 +11,6 @@ import com.xueyi.common.core.constant.digitalman.SkillConstants.SkillType;
import com.xueyi.common.core.context.SecurityContextHolder;
import com.xueyi.common.core.utils.DateUtil;
import com.xueyi.common.core.utils.core.IdUtil;
import com.xueyi.common.core.utils.core.SpringUtils;
import com.xueyi.common.core.web.result.AjaxResult;
import com.xueyi.common.core.web.result.R;
import com.xueyi.common.core.web.validate.V_A;
@@ -33,7 +32,7 @@ import com.xueyi.nlt.api.nlt.domain.vo.response.DmKnowledgeResponse;
import com.xueyi.nlt.api.nlt.feign.RemoteIntentService;
import com.xueyi.nlt.api.nlt.feign.RemoteLandingLlmService;
import com.xueyi.nlt.api.nlt.feign.RemoteQAService;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
@@ -54,8 +53,6 @@ import com.xueyi.system.api.digitalmans.feign.RemoteDigitalmanService;
import com.xueyi.system.api.digitalmans.feign.RemoteManDeviceService;
import com.xueyi.system.api.digitalmans.feign.RemoteQuestionanswersService;
import com.xueyi.system.api.digitalmans.feign.RemoteSkillService;
import com.xueyi.system.api.interfaces.airport.domain.vo.PlaneMessageVo;
import com.xueyi.system.api.interfaces.airport.feign.RemotePlaneController;
import com.xueyi.system.api.model.Source;
import com.xueyi.system.api.organize.domain.dto.SysEnterpriseDto;
import com.xueyi.system.api.organize.feign.RemoteEnterpriseService;
@@ -64,7 +61,6 @@ import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.util.DigestUtils;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
@@ -77,9 +73,7 @@ import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import java.io.Serializable;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
@@ -99,7 +93,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt
IDmIntentService dmIntentService;

@Autowired
WebSocketClient webSocketClient;
WebSocketClientManager webSocketClientManager;

/** 定义节点名称 */
@Override
@@ -598,6 +592,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) {
String prefix = "假设你是一名公司前台,你看到" + message.getFormat().get("name")+ "已知今天是他的生日。请你从个人角度输出给他的生日贺词。要求待人平和,具有人情味,用词正式,内容与工作无关。输出只包含你要对他说的话,在20字以内。";
LlmContext context = new LlmContext(prefix);
context.setDevId(message.getDevId());
LlmResponse response = sysLlmService.chat(context, new LlmParam());
JSONObject birthdayJo = new JSONObject();
birthdayJo.put("content", response.getContent());
@@ -608,6 +603,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) {
String prefix = "假设你是一名公司前台,你看到"+ message.getFormat().get("name")+ ",已知今天是他入职" + message.getFormat().get("years")+"周年,请你从个人角度说出对他入职周年的祝贺。要求具有人情味,有特色,字数在25字左右,不要提到生日,要带人名。输出只包含你要对他说的话。";
LlmContext context = new LlmContext(prefix);
context.setDevId(message.getDevId());
LlmResponse response = sysLlmService.chat(context, new LlmParam());
JSONObject hireDateJo = new JSONObject();
hireDateJo.put("content", response.getContent());
@@ -620,6 +616,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt
String prefix = "假设你是一名公司前台,你看到在你们公司工作的\\"+ jo.getString("orderName")+ "\\,请你从个人角度提醒他参加\\" +
dateFormat4.format(timestamp) + "\\在\\" + meetingRoom + "\\的会,要求语气友好。输出只包含你要对他说的话,在20字左右。";
LlmContext context = new LlmContext(prefix);
context.setDevId(message.getDevId());
LlmResponse response = sysLlmService.chat(context, new LlmParam());
JSONObject meetingJo = new JSONObject();
meetingJo.put("content", response.getContent());


+ 10
- 1
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/domain/LlmContext.java Целия файл

@@ -1,6 +1,7 @@
package com.xueyi.nlt.nlt.domain;

import com.alibaba.fastjson2.JSONArray;
import com.xueyi.common.core.utils.core.IdUtil;
import lombok.Data;
import lombok.NoArgsConstructor;

@@ -10,14 +11,22 @@ import java.util.List;
import java.util.stream.Collectors;

@Data
@NoArgsConstructor
public class LlmContext implements Serializable {

private String devId;
private List<LlmContent> contentList;

private Long traceId;


public LlmContext() {
// 创建随机的traceId
traceId = IdUtil.getSnowflakeNextId();
contentList = new ArrayList<>();
}

public LlmContext(String message) {
traceId = IdUtil.getSnowflakeNextId();
contentList = new ArrayList<>();
LlmContent ctx = new LlmContent("user",message);
contentList.add(ctx);


+ 15
- 2
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/domain/LlmParam.java Целия файл

@@ -5,13 +5,26 @@ import lombok.Data;
@Data
public class LlmParam {

// 模型选择
protected String llm;
protected Integer maxTokens;
protected Double temperature;
protected Integer topK;

// 星火大模型
protected String model;
protected String randomThreshold;

/**
* 构造函数(无参数)
* 默认为星火大模型
*/
public LlmParam() {
this.llm = "spark";
this.model = "generalv3";
this.maxTokens = 1024;
this.temperature = 0.75;
this.topK = 1;
this.temperature = 0.5;
this.topK = 4;

}
}

+ 25
- 11
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/service/impl/SparkServiceImpl.java Целия файл

@@ -1,13 +1,16 @@
package com.xueyi.nlt.nlt.service.impl;

import com.xueyi.common.core.utils.core.SpringUtils;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import com.xueyi.nlt.netty.client.listener.LlmWebSocketListener;
import com.xueyi.nlt.netty.server.config.ServerConfig;
import com.xueyi.nlt.nlt.domain.LlmContent;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
import com.xueyi.nlt.nlt.domain.LlmResponse;
import com.xueyi.nlt.nlt.service.ISysLlmService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
@@ -20,26 +23,35 @@ import java.util.stream.Collectors;
public class SparkServiceImpl implements ISysLlmService {

@Autowired
WebSocketClient webSocketClient;
WebSocketClientManager webSocketClientManager;

@Autowired
private StringRedisTemplate redisTemplate;

@Value("${secret.spark.appId}")
private String appId;

@Override
public LlmResponse chat(LlmContext context, LlmParam param) {
List<String> contentArr = context.getContentList().stream().map(LlmContent::getContent).collect(Collectors.toList());
WebSocketClient socketClient = SpringUtils.getBean(WebSocketClient.class);
webSocketClient = socketClient.sendMsg(contentArr);
synchronized (webSocketClient) {
LlmWebSocketListener listener = new LlmWebSocketListener(appId, param, context,false);
ServerConfig.currentTraceMap.put(context.getDevId(),context.getTraceId());
webSocketClientManager.startWebSocketClient(listener);
synchronized (listener) {
try {
webSocketClient.wait();
listener.wait();

} catch (InterruptedException e) {
e.printStackTrace();
Thread.currentThread().interrupt();
}
String result = redisTemplate.opsForValue().get("group:websocket:content");
result = webSocketClient.answer;
String result = listener.answer;
// 添加上下文

// // 添加上下文
// INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", questions.get(questions.size() - 1));
// INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", answer);
// // 添加缓存
// INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", answer);
LlmResponse response = new LlmResponse();
response.setContent(result);
return response;
@@ -50,8 +62,10 @@ public class SparkServiceImpl implements ISysLlmService {
@Override
public LlmResponse stream(LlmContext context, LlmParam param) {
List<String> contentArr = context.getContentList().stream().map(LlmContent::getContent).collect(Collectors.toList());
webSocketClient.sendMsg(contentArr, true,context.getDevId());
LlmResponse response = new LlmResponse();
LlmWebSocketListener listener = new LlmWebSocketListener(appId, param, context,true);
ServerConfig.currentTraceMap.put(context.getDevId(),context.getTraceId());
webSocketClientManager.startWebSocketClient(listener);
LlmResponse response = new LlmResponse();
return response;
}
}

+ 64
- 70
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/FreeChatTemplate.java Целия файл

@@ -1,11 +1,9 @@
package com.xueyi.nlt.nlt.template;

import com.alibaba.druid.util.StringUtils;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONException;
import com.alibaba.fastjson2.JSONObject;
import com.xueyi.common.core.context.SecurityContextHolder;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
@@ -49,87 +47,83 @@ public class FreeChatTemplate implements BaseTemplate{
Long operatorId = TerminalSecurityContextHolder.getOperatorId();
String redisKey = "group:nlp:" + SecurityContextHolder.getLocalMap().get("enterprise_id") + ":" + operatorId;
// 根据content内容调用模版并返回结果
synchronized (WebSocketClient.LOCK) {
// 通过redis获取数字人上下文信息
Long size = redisTemplate.opsForList().size(redisKey);
if (size > 8) {
redisTemplate.opsForList().leftPop(redisKey,2);
}
size = redisTemplate.opsForList().size(redisKey);
List<String> context = new ArrayList<>();
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人。");
context.addAll(redisTemplate.opsForList().range(redisKey,size-6,size));
// 通过redis获取数字人上下文信息
Long size = redisTemplate.opsForList().size(redisKey);
if (size > 8) {
redisTemplate.opsForList().leftPop(redisKey,2);
}
size = redisTemplate.opsForList().size(redisKey);
List<String> context = new ArrayList<>();
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人。");
context.addAll(redisTemplate.opsForList().range(redisKey,size-6,size));

context.add(content);
context.add(content);

//webSocketClient.sendMsg(context);
//webSocketClient.sendMsg(context);

LlmContext llmContext = LlmContext.parse(context,true);
LlmParam param = new LlmParam();
LlmResponse response = sysLlmService.chat(llmContext,param);
String result = response.getContent();
LlmContext llmContext = LlmContext.parse(context,true);
LlmParam param = new LlmParam();
LlmResponse response = sysLlmService.chat(llmContext,param);
String result = response.getContent();

// 处理数据
if (result.contains("我是科大讯飞")) {
result = result.replaceAll("科大讯飞", "缔智元");
}
result = result.replaceAll("认知模型", "数字员工");
result = result.replaceAll("认知智能模型", "数字员工");
if (result.equals("-1")) {
result = "这个问题超出了我无法回答,您可以提出更多关于电影相关问题。";
}
// 处理数据
if (result.contains("我是科大讯飞")) {
result = result.replaceAll("科大讯飞", "缔智元");
}
result = result.replaceAll("认知模型", "数字员工");
result = result.replaceAll("认知智能模型", "数字员工");
if (result.equals("-1")) {
result = "这个问题超出了我无法回答,您可以提出更多关于电影相关问题。";
}

if(!StringUtils.isEmpty(result)){
redisTemplate.opsForList().rightPush(redisKey,content);
redisTemplate.opsForList().rightPush(redisKey,result);
}
JSONObject resultJson = new JSONObject();
resultJson.put("msg",result);
return resultJson;
if(!StringUtils.isEmpty(result)){
redisTemplate.opsForList().rightPush(redisKey,content);
redisTemplate.opsForList().rightPush(redisKey,result);
}
JSONObject resultJson = new JSONObject();
resultJson.put("msg",result);
return resultJson;
}

public JSONObject handle(String dev, String content, boolean stream) {
Long operatorId = TerminalSecurityContextHolder.getOperatorId();
String redisKey = "group:nlp:" + SecurityContextHolder.getLocalMap().get("enterprise_id") + ":" + operatorId;
// 根据content内容调用模版并返回结果
synchronized (WebSocketClient.LOCK) {
// 通过redis获取数字人上下文信息
Long size = redisTemplate.opsForList().size(redisKey);
if (size > 8) {
redisTemplate.opsForList().leftPop(redisKey,2);
}
size = redisTemplate.opsForList().size(redisKey);
List<String> context = new ArrayList<>();
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人,你们公司在北京,你的职能是负责做会议预定和访客预约。");
// 中航信领导来访临时对策
// 判断如果content包含correctWordsMap中的key,则替换为value
for (String key : correctWordsMap.keySet()) {
if (content.contains(key)) {
content = content.replaceAll(key, correctWordsMap.get(key));
}
// 通过redis获取数字人上下文信息
Long size = redisTemplate.opsForList().size(redisKey);
if (size > 8) {
redisTemplate.opsForList().leftPop(redisKey,2);
}
size = redisTemplate.opsForList().size(redisKey);
List<String> context = new ArrayList<>();
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人,你们公司在北京,你的职能是负责做会议预定和访客预约。");
// 中航信领导来访临时对策
// 判断如果content包含correctWordsMap中的key,则替换为value
for (String key : correctWordsMap.keySet()) {
if (content.contains(key)) {
content = content.replaceAll(key, correctWordsMap.get(key));
}

//context.addAll(redisTemplate.opsForList().range(redisKey,size-2,size));
content = "请用简短的话回答下面的问题:" + content;
context.add(content);
//使用stream去除context列表中所有字符串中的引号
context = context.stream().map(s -> s.replaceAll("\"", "")).collect(java.util.stream.Collectors.toList());

//webSocketClient.sendMsg(context);

LlmContext llmContext = LlmContext.parse(context,true);
llmContext.setDevId(dev);
LlmParam param = new LlmParam();
LlmResponse response = sysLlmService.stream(llmContext,param);

JSONObject resultJson = new JSONObject();
resultJson.put("tts","让我想一想。");
resultJson.put("motion","idle");
resultJson.put("status","0");
resultJson.put("action","chat");
return resultJson;
}

//context.addAll(redisTemplate.opsForList().range(redisKey,size-2,size));
content = "请用简短的话回答下面的问题:" + content;
context.add(content);
//使用stream去除context列表中所有字符串中的引号
context = context.stream().map(s -> s.replaceAll("\"", "")).collect(java.util.stream.Collectors.toList());

//webSocketClient.sendMsg(context);

LlmContext llmContext = LlmContext.parse(context,true);
llmContext.setDevId(dev);
LlmParam param = new LlmParam();
LlmResponse response = sysLlmService.stream(llmContext,param);

JSONObject resultJson = new JSONObject();
resultJson.put("tts","让我想一想。");
resultJson.put("motion","idle");
resultJson.put("status","0");
resultJson.put("action","chat");
return resultJson;
}
@Override
public JSONObject handle(String dev, String content, Long tenantId) {


+ 25
- 28
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/GenerativeKnowledgeTemplate.java Целия файл

@@ -3,7 +3,11 @@ package com.xueyi.nlt.nlt.template;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONException;
import com.alibaba.fastjson2.JSONObject;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
import com.xueyi.nlt.nlt.domain.LlmResponse;
import com.xueyi.nlt.nlt.service.ISysLlmService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
@@ -16,41 +20,34 @@ public class GenerativeKnowledgeTemplate implements BaseTemplate{

private static final Logger log = LoggerFactory.getLogger(GenerativeKnowledgeTemplate.class);
@Autowired
WebSocketClient webSocketClient;
WebSocketClientManager webSocketClientManager;

@Autowired
private ISysLlmService sysLlmService;

@Autowired
private RedisTemplate<String,String> redisTemplate;

@Override
public JSONObject handle(String devId, String content) {
JSONObject jsonObject = new JSONObject();
// 根据content内容调用模版并返回结果
synchronized (WebSocketClient.LOCK) {

String prefix = "你的任务是[针对给定的文段提出" + (content.length() /100 + 1 ) + "个问题并回答]。文段为:[\"";
String suffix = "\"]。输出为一个JSON数组[{}],每个元素是一个JSON:{“question”:,”answer”:}。不要给出任何解释说明。";
log.info(prefix + content + suffix);
webSocketClient.sendMsg(prefix + content + suffix);
try {
WebSocketClient.LOCK.wait();
String result = (String)redisTemplate.opsForValue().get("group:websocket:content");
try {
JSONArray jsonArray = JSONArray.parseArray(result);
JSONObject jsonObject = new JSONObject();
jsonObject.put("questions",jsonArray);
return jsonObject;
} catch (JSONException je) {
// 返回结果错误,计日志,存log,返回空结果
log.error(je.getMessage(),je);
return new JSONObject();
}

} catch (InterruptedException e) {
log.warn(e.getMessage());
Thread.currentThread().interrupt();
}
String prefix = "你的任务是[针对给定的文段提出" + (content.length() /100 + 1 ) + "个问题并回答]。文段为:[\"";
String suffix = "\"]。输出为一个JSON数组[{}],每个元素是一个JSON:{“question”:,”answer”:}。不要给出任何解释说明。";
log.info(prefix + content + suffix);
LlmContext llmContext = new LlmContext(prefix + content + suffix);
llmContext.setDevId(devId);
LlmParam llmParam = new LlmParam();
LlmResponse response = sysLlmService.chat(llmContext,llmParam);
try {
JSONArray jsonArray = JSONArray.parseArray(response.getContent());
jsonObject.put("questions",jsonArray);
return jsonObject;
} catch (JSONException je) {
// 返回结果错误,计日志,存log,返回空结果
log.error(je.getMessage(),je);
}
return new JSONObject();
return jsonObject;
}

@Override


+ 2
- 14
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/MeetingOrderTemplate.java Целия файл

@@ -2,9 +2,7 @@ package com.xueyi.nlt.nlt.template;

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.query_dsl.MatchQuery;
import co.elastic.clients.elasticsearch._types.query_dsl.Query;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.search.Hit;
import com.alibaba.cloud.nacos.NacosConfigManager;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONException;
@@ -15,14 +13,8 @@ import com.xueyi.common.core.constant.basic.SecurityConstants;
import com.xueyi.common.core.constant.digitalman.SkillConstants;
import com.xueyi.common.core.web.result.R;
import com.xueyi.nlt.api.nlt.domain.vo.CoversationSessionVo;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.nlt.config.MeetingParam;
import com.xueyi.nlt.nlt.controller.DmIntentController;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
import com.xueyi.nlt.nlt.domain.LlmResponse;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import com.xueyi.nlt.nlt.domain.vo.MeetingParamVo;
import com.xueyi.nlt.nlt.service.ISysLlmService;
import com.xueyi.system.api.digitalmans.domain.po.DmDigitalmanExtPo;
import com.xueyi.system.api.digitalmans.feign.RemoteDigitalmanService;
import com.xueyi.system.api.meeting.domain.dto.DmMeetingRoomsDto;
@@ -30,23 +22,19 @@ import com.xueyi.system.api.meeting.feign.RemoteMeetingService;
import com.xueyi.system.api.model.Source;
import com.xueyi.system.api.organize.domain.dto.SysEnterpriseDto;
import com.xueyi.system.api.organize.feign.RemoteEnterpriseService;
import com.xueyi.system.api.pass.domain.po.DmRecognizedRecordsPo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
@@ -59,7 +47,7 @@ public class MeetingOrderTemplate implements BaseTemplate {

private static final Logger log = LoggerFactory.getLogger(MeetingOrderTemplate.class);
@Autowired
WebSocketClient webSocketClient;
WebSocketClientManager webSocketClientManager;

// @Autowired
// RemoteMeetingService remoteMeetingService;


+ 2
- 2
xueyi-modules/xueyi-nlt/src/main/java/com/xueyi/nlt/nlt/template/MovieChatTemplate.java Целия файл

@@ -3,7 +3,7 @@ package com.xueyi.nlt.nlt.template;
import com.alibaba.druid.util.StringUtils;
import com.alibaba.fastjson2.JSONObject;
import com.xueyi.common.core.context.SecurityContextHolder;
import com.xueyi.nlt.netty.client.WebSocketClient;
import com.xueyi.nlt.netty.client.WebSocketClientManager;
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder;
import com.xueyi.nlt.nlt.domain.LlmContext;
import com.xueyi.nlt.nlt.domain.LlmParam;
@@ -24,7 +24,7 @@ public class MovieChatTemplate implements BaseTemplate{

private static final Logger log = LoggerFactory.getLogger(MovieChatTemplate.class);
@Autowired
WebSocketClient webSocketClient;
WebSocketClientManager webSocketClientManager;

@Autowired
ISysLlmService sysLlmService;


Зареждане…
Отказ
Запис