@@ -1,521 +0,0 @@ | |||
package com.xueyi.nlt.netty.client; | |||
import cn.hutool.core.lang.Snowflake; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.alibaba.nacos.shaded.com.google.gson.Gson; | |||
import com.alibaba.nacos.shaded.com.google.gson.JsonArray; | |||
import com.alibaba.nacos.shaded.com.google.gson.JsonObject; | |||
import com.baomidou.mybatisplus.core.toolkit.StringUtils; | |||
import com.xueyi.common.core.utils.core.IdUtil; | |||
import com.xueyi.common.web.interceptor.ApiRequestInterceptor; | |||
import com.xueyi.nlt.api.netty.domain.vo.DmWebSocketMessageVo; | |||
import com.xueyi.nlt.netty.server.config.ServerConfig; | |||
import com.xueyi.nlt.nlt.domain.vo.LlmQueryVo; | |||
import io.netty.channel.Channel; | |||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; | |||
import okhttp3.*; | |||
import org.apache.commons.collections4.CollectionUtils; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.beans.factory.annotation.Value; | |||
import org.springframework.data.redis.core.RedisTemplate; | |||
import org.springframework.data.redis.core.StringRedisTemplate; | |||
import org.springframework.kafka.core.KafkaTemplate; | |||
import org.springframework.stereotype.Component; | |||
import javax.annotation.PostConstruct; | |||
import javax.crypto.Mac; | |||
import javax.crypto.spec.SecretKeySpec; | |||
import java.io.Serializable; | |||
import java.net.URL; | |||
import java.nio.charset.Charset; | |||
import java.text.SimpleDateFormat; | |||
import java.util.*; | |||
@Component | |||
public class WebSocketClient extends WebSocketListener { | |||
private static final Logger logger = LoggerFactory.getLogger(WebSocketClient.class); | |||
public static WebSocketClient INSTANCE; | |||
@Autowired | |||
private RedisTemplate redisTemplate; | |||
@Autowired | |||
private StringRedisTemplate stringRedisTemplate; | |||
@Value("${secret.spark.appId}") | |||
private String appId; | |||
@Value("${secret.spark.apiSecret}") | |||
private String apiSecret; | |||
@Value("${secret.spark.apiKey}") | |||
private String apiKey; | |||
@Value("${secret.spark.hostUrl}") | |||
public String hostUrl; | |||
public final static Object LOCK = new Object(); | |||
// public static String APPID = "3d9282da";//从开放平台控制台中获取 | |||
// public static String APIKEY = "7c217b3a313f4b66fcc14a8e97f85103";//从开放平台控制台中获取 | |||
// public static String APISecret = "ZTRiNDQwMTRlOTlmZDQwMDUwYTdjMDM0";//从开放平台控制台中获取 | |||
public WebSocket webSocket; | |||
public static final Gson json = new Gson(); | |||
// public static String question = "假设你是一位前台,你需要通过与其他人对话来获取会议相关信息,已知今天是2023-7-19,你需要获取会议日期,开始时间,持续时间,会议地点,会议主题。时间用类似00:00的格式输出。对方的话中可能不包含全部信息,对于未知的信息填充为none。如果所有信息都已知那么commit为true。否则为false。将你获得的信息输出为json格式。对方的话是:“明天下午开个会。从两点开到下午三点,在大会议室开,主题是访客接待”,只输出最后的json。输出只有一行,输出格式为{date:,start_time:,duration:,location:,theme:commit:}。";//可以修改question 内容,来向模型提问 | |||
// 定义内存共享变量traceId | |||
public Long traceId; | |||
public String question = "请帮我安排五一出行计划";//可以修改question 内容,来向模型提问 | |||
public String systemRole = ""; | |||
public List<String> questions = new ArrayList<>();//可以修改question 内容,来向模型提问 | |||
public boolean stream = false; | |||
public String curUserId = null; | |||
public String answer = ""; | |||
public String answerBuf = ""; | |||
@PostConstruct | |||
public void init() { | |||
INSTANCE = this; | |||
INSTANCE.redisTemplate = this.redisTemplate; | |||
INSTANCE.stringRedisTemplate = this.stringRedisTemplate; | |||
INSTANCE.appId = this.appId; | |||
} | |||
public static void main(String[] args) { | |||
synchronized (LOCK) { | |||
try { | |||
//构建鉴权httpurl | |||
String authUrl = getAuthorizationUrl("https://spark-api.xf-yun.com/v3.1/chat", "54f6e81f40a31d66d976496de895a7a4", "ZDYyMjNmMTlkYTE0YWRmOWUwZTYxNjYz"); | |||
OkHttpClient okHttpClient = new OkHttpClient.Builder().build(); | |||
String url = authUrl.replace("https://","wss://").replace("http://","ws://"); | |||
Request request = new Request.Builder().url(url).build(); | |||
WebSocket webSocket = okHttpClient.newWebSocket(request,new WebSocketClient()); | |||
LOCK.wait(); | |||
System.out.println("查询完成"); | |||
} catch (InterruptedException ie) { | |||
ie.printStackTrace(); | |||
Thread.currentThread().interrupt(); | |||
} catch (Exception e) { | |||
e.printStackTrace(); | |||
} | |||
} | |||
// write your code here | |||
} | |||
/** | |||
* 调用讯飞开放平台接口发送消息,并将消息存入redis队列 | |||
* @param message | |||
* @param key | |||
*/ | |||
public void sendMsg(String message,String key){ | |||
LlmQueryVo vo = new LlmQueryVo(); | |||
vo.setQuestion(message); | |||
vo.setTemplate(key); | |||
redisTemplate.opsForList().rightPush("group:websocket:quary",vo); | |||
sendMsg(message); | |||
} | |||
public void sendMsg(String message){ | |||
question = message; | |||
try { | |||
//构建鉴权httpurl | |||
String authUrl = getAuthorizationUrl(hostUrl,apiKey,apiSecret); | |||
OkHttpClient okHttpClient = new OkHttpClient.Builder().build(); | |||
String url = authUrl.replace("https://","wss://").replace("http://","ws://"); | |||
Request request = new Request.Builder().url(url).build(); | |||
webSocket = okHttpClient.newWebSocket(request,new WebSocketClient()); | |||
} catch (Exception e) { | |||
e.printStackTrace(); | |||
} | |||
} | |||
public WebSocketClient sendMsg(List<String> messages){ | |||
return this.sendMsg(messages,false, "null"); | |||
} | |||
public WebSocketClient sendMsg(List<String> messages,boolean stream,String userId){ | |||
this.stream = stream; | |||
this.curUserId = userId; | |||
if (messages.size() / 2 > 0) { | |||
systemRole = messages.get(0); | |||
messages.remove(0); | |||
} | |||
questions = messages; | |||
question = null; | |||
try { | |||
//构建鉴权httpurl | |||
String authUrl = getAuthorizationUrl(hostUrl,apiKey,apiSecret); | |||
OkHttpClient okHttpClient = new OkHttpClient.Builder().build(); | |||
String url = authUrl.replace("https://","wss://").replace("http://","ws://"); | |||
Request request = new Request.Builder().url(url).build(); | |||
WebSocketClient wsc = new WebSocketClient(); | |||
wsc.stream = stream; | |||
wsc.curUserId = userId; | |||
wsc.questions = questions; | |||
wsc.question = question; | |||
wsc.systemRole = systemRole; | |||
wsc.traceId = IdUtil.getSnowflakeNextId(); | |||
ServerConfig.currentTraceMap.put(curUserId,wsc.traceId); | |||
System.out.println("wocket客户端:" + wsc.hashCode()); | |||
webSocket = okHttpClient.newWebSocket(request,wsc); | |||
return wsc; | |||
} catch (Exception e) { | |||
e.printStackTrace(); | |||
return null; | |||
} | |||
} | |||
//鉴权url | |||
public static String getAuthorizationUrl(String hostUrl , String apikey ,String apisecret) throws Exception { | |||
//获取host | |||
URL url = new URL(hostUrl); | |||
//获取鉴权时间 date | |||
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US); | |||
System.out.println("format:\n" + format ); | |||
format.setTimeZone(TimeZone.getTimeZone("GMT")); | |||
String date = format.format(new Date()); | |||
//获取signature_origin字段 | |||
StringBuilder builder = new StringBuilder("host: ").append(url.getHost()).append("\n"). | |||
append("date: ").append(date).append("\n"). | |||
append("GET ").append(url.getPath()).append(" HTTP/1.1"); | |||
System.out.println("signature_origin:\n" + builder); | |||
//获得signatue | |||
Charset charset = Charset.forName("UTF-8"); | |||
Mac mac = Mac.getInstance("hmacsha256"); | |||
SecretKeySpec sp = new SecretKeySpec(apisecret.getBytes(charset),"hmacsha256"); | |||
mac.init(sp); | |||
byte[] basebefore = mac.doFinal(builder.toString().getBytes(charset)); | |||
String signature = Base64.getEncoder().encodeToString(basebefore); | |||
//获得 authorization_origin | |||
String authorization_origin = String.format("api_key=\"%s\",algorithm=\"%s\",headers=\"%s\",signature=\"%s\"",apikey,"hmac-sha256","host date request-line",signature); | |||
//获得authorization | |||
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes(charset)); | |||
//获取httpurl | |||
HttpUrl httpUrl = HttpUrl.parse("https://" + url.getHost() + url.getPath()).newBuilder().// | |||
addQueryParameter("authorization", authorization).// | |||
addQueryParameter("date", date).// | |||
addQueryParameter("host", url.getHost()).// | |||
build(); | |||
return httpUrl.toString(); | |||
} | |||
//重写onopen | |||
@Override | |||
public void onOpen(WebSocket webSocket, Response response) { | |||
super.onOpen(webSocket, response); | |||
new Thread(()->{ | |||
JsonObject frame = new JsonObject(); | |||
JsonObject header = new JsonObject(); | |||
JsonObject chat = new JsonObject(); | |||
JsonObject parameter = new JsonObject(); | |||
JsonObject payload = new JsonObject(); | |||
JsonObject message = new JsonObject(); | |||
JsonObject text = new JsonObject(); | |||
JsonArray ja = new JsonArray(); | |||
//填充header | |||
header.addProperty("app_id",INSTANCE.appId); | |||
header.addProperty("uid","123456789"); | |||
//填充parameter | |||
// chat.addProperty("domain","general"); //1.0版本 | |||
chat.addProperty("domain","generalv3"); // 3.0版本 | |||
chat.addProperty("random_threshold",0.5); | |||
chat.addProperty("max_tokens",1024); | |||
chat.addProperty("auditing","default"); | |||
parameter.add("chat",chat); | |||
if (!StringUtils.isEmpty(systemRole)) { | |||
text = new JsonObject(); | |||
//填充payload | |||
text.addProperty("role","system"); | |||
text.addProperty("content",systemRole); | |||
ja.add(text); | |||
} | |||
if (!StringUtils.isEmpty(question)) { | |||
text = new JsonObject(); | |||
//填充payload | |||
text.addProperty("role","user"); | |||
text.addProperty("content",question); | |||
ja.add(text); | |||
}else { | |||
for (int i = 0;i < questions.size();i++) { | |||
text = new JsonObject(); | |||
if (i % 2 == 0) { | |||
text.addProperty("role","user"); | |||
} else { | |||
text.addProperty("role","assistant"); | |||
} | |||
text.addProperty("content",questions.get(i)); | |||
System.out.println(text.toString()); | |||
ja.add(text); | |||
} | |||
} | |||
// message.addProperty("text",ja.getAsString()); | |||
message.add("text",ja); | |||
payload.add("message",message); | |||
frame.add("header",header); | |||
frame.add("parameter",parameter); | |||
frame.add("payload",payload); | |||
System.out.println("frame:\n" + frame.toString()); | |||
webSocket.send(frame.toString()); | |||
} | |||
).start(); | |||
} | |||
//重写onmessage | |||
@Override | |||
public void onMessage(WebSocket webSocket, String text) { | |||
super.onMessage(webSocket, text); | |||
System.out.println("text:\n" + text); | |||
if (!StringUtils.isEmpty(curUserId)) { | |||
if (ServerConfig.currentTraceMap.containsKey(curUserId) && !ServerConfig.currentTraceMap.get(curUserId).equals(traceId)) { | |||
return; | |||
} | |||
} | |||
synchronized (this) { | |||
ResponseData responseData = json.fromJson(text,ResponseData.class); | |||
try { | |||
// System.out.println("code:\n" + responseData.getHeader().get("code")); | |||
if (0 == responseData.getHeader().get("code").getAsInt()) { | |||
System.out.println("###########"); | |||
System.out.println("getStatus: " + responseData.getHeader().get("status").getAsInt()); | |||
if (2 != responseData.getHeader().get("status").getAsInt()) { | |||
System.out.println("****************"); | |||
Payload pl = json.fromJson(responseData.getPayload(), Payload.class); | |||
JsonArray temp = (JsonArray) pl.getChoices().get("text"); | |||
JsonObject jo = (JsonObject) temp.get(0); | |||
answer += jo.get("content").getAsString(); | |||
answerBuf += jo.get("content").getAsString(); | |||
// System.out.println(answer); | |||
} else { | |||
Payload pl1 = json.fromJson(responseData.getPayload(), Payload.class); | |||
JsonObject jsonObject = (JsonObject) pl1.getUsage().get("text"); | |||
int prompt_tokens = jsonObject.get("prompt_tokens").getAsInt(); | |||
JsonArray temp1 = (JsonArray) pl1.getChoices().get("text"); | |||
JsonObject jo = (JsonObject) temp1.get(0); | |||
answer += jo.get("content").getAsString(); | |||
answerBuf += jo.get("content").getAsString(); | |||
System.out.println("返回结果为:\n" + answer); | |||
if (INSTANCE.redisTemplate.hasKey("gpt:websocket:1")) { | |||
DmWebSocketMessageVo message = (DmWebSocketMessageVo) INSTANCE.redisTemplate.opsForValue().get("gpt:websocket:1"); | |||
if (message != null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) { | |||
JSONObject birthdayJo = new JSONObject(); | |||
birthdayJo.put("content", answer); | |||
birthdayJo.put("timestamp", message.getFormat().get("timestamp")); | |||
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "birthday", birthdayJo.toString()); | |||
INSTANCE.redisTemplate.delete("gpt:websocket:1"); | |||
this.notifyAll(); | |||
return; | |||
} | |||
if (message!= null && StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) { | |||
JSONObject birthdayJo = new JSONObject(); | |||
birthdayJo.put("content", answer); | |||
birthdayJo.put("timestamp", message.getFormat().get("timestamp")); | |||
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + message.getFormat().getString("orderId"), "hireDate", birthdayJo.toString()); | |||
INSTANCE.redisTemplate.delete("gpt:websocket:1"); | |||
this.notifyAll(); | |||
return; | |||
} | |||
if (message != null) { | |||
JSONObject preWebsocketJo = message.getFormat(); | |||
JSONObject meetingJo = new JSONObject(); | |||
meetingJo.put("timestamp",preWebsocketJo.get("timestamp")); | |||
meetingJo.put("content",answer); | |||
INSTANCE.stringRedisTemplate.opsForHash().put("group:nlp" + ":" + preWebsocketJo.getString("orderId"), "meeting", meetingJo.toString()); | |||
this.notifyAll(); | |||
return; | |||
} | |||
INSTANCE.redisTemplate.delete("gpt:websocket:1"); | |||
// 清除systemRole | |||
systemRole = ""; | |||
}else { | |||
// 添加上下文 | |||
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", questions.get(questions.size() - 1)); | |||
INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", answer); | |||
// 添加缓存 | |||
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", answer); | |||
} | |||
this.notifyAll(); | |||
// webSocket.close(3,"客户端主动断开链接"); | |||
//webSocket.close(1000,"客户端主动断开链接"); | |||
} | |||
if (this.stream && !StringUtils.isEmpty(curUserId)) { | |||
Channel ch = ServerConfig.sessionMap.get(curUserId); | |||
logger.info("当前ch:{}",ch.id().asLongText()); | |||
if (ch != null) { | |||
List<String> ttsList = new ArrayList<>(); | |||
JSONObject jo = new JSONObject(); | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",traceId); | |||
//去除转义符 | |||
answerBuf = answerBuf.replaceAll("[\\r\\n]", ""); | |||
//去除引号 | |||
answerBuf = answerBuf.replaceAll("\"", ""); | |||
// 处理answer,如果包含"。",则将"。"之前的内容发送给前端 | |||
while(answerBuf.contains("。") || answerBuf.contains("?") || answerBuf.contains("!") || | |||
answerBuf.contains("?") || answerBuf.contains("!")) { | |||
String[] temp = answerBuf.split("。|?|!|\\?|\\!"); | |||
ttsList.add(temp[0] + answerBuf.charAt(temp[0].length())); | |||
answerBuf = answerBuf.substring(temp[0].length() + 1); | |||
} | |||
if (2 == responseData.getHeader().get("status").getAsInt() && CollectionUtils.isEmpty(ttsList)) { | |||
jo.put("tts",answerBuf); | |||
jo.put("status",2); | |||
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString())); | |||
} else { | |||
for (int i = 0;i <ttsList.size();i++) { | |||
if (2 == responseData.getHeader().get("status").getAsInt() && i == ttsList.size() - 1) { | |||
jo.put("status",2); | |||
} else { | |||
jo.put("status",1); | |||
} | |||
jo.put("tts",ttsList.get(i)); | |||
String str = jo.toJSONString(); | |||
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(str)); | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
// 添加缓存 | |||
INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", "-1"); | |||
// 判断流式则返回结束状态 | |||
if (stream == true) { | |||
Channel ch = ServerConfig.sessionMap.get(curUserId); | |||
if (ch != null) { | |||
JSONObject jo = new JSONObject(); | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",traceId); | |||
jo.put("status",2); | |||
jo.put("tts","抱歉,您的问题我无法解答。"); | |||
String str = jo.toJSONString(); | |||
logger.info("发送到client:{},id:{},内容:{}",curUserId,ch.id().asLongText(),jo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(str)); | |||
} | |||
} | |||
System.out.println("返回结果错误:\n" + responseData.getHeader().get("code") + responseData.getHeader().get("message")); | |||
this.notifyAll(); | |||
} | |||
} catch (Exception e) { | |||
if (StringUtils.isNotEmpty(curUserId)) { | |||
Channel ch = ServerConfig.sessionMap.get(curUserId); | |||
JSONObject jo = new JSONObject(); | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",traceId); | |||
jo.put("status",2); | |||
jo.put("tts","大模型出现异常,请稍后重试。"); | |||
String str = jo.toJSONString(); | |||
logger.info("发生异常client:{},内容:{}",curUserId,jo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(str)); | |||
} | |||
e.printStackTrace(); | |||
this.notifyAll(); | |||
} | |||
} | |||
} | |||
//重写onFailure | |||
@Override | |||
public void onFailure(WebSocket webSocket, Throwable t, Response response) { | |||
super.onFailure(webSocket, t, response); | |||
System.out.println(response); | |||
} | |||
class ResponseData{ | |||
private JsonObject header; | |||
private JsonObject payload; | |||
public JsonObject getHeader() { | |||
return header; | |||
} | |||
public JsonObject getPayload() { | |||
return payload; | |||
} | |||
} | |||
class Header{ | |||
private int code ; | |||
private String message; | |||
private String sid; | |||
private String status; | |||
public int getCode() { | |||
return code; | |||
} | |||
public String getMessage() { | |||
return message; | |||
} | |||
public String getSid() { | |||
return sid; | |||
} | |||
public String getStatus() { | |||
return status; | |||
} | |||
} | |||
class Payload{ | |||
private JsonObject choices; | |||
private JsonObject usage; | |||
public JsonObject getChoices() { | |||
return choices; | |||
} | |||
public JsonObject getUsage() { | |||
return usage; | |||
} | |||
} | |||
class Choices{ | |||
private int status; | |||
private int seq; | |||
private JsonArray text; | |||
public int getStatus() { | |||
return status; | |||
} | |||
public int getSeq() { | |||
return seq; | |||
} | |||
public JsonArray getText() { | |||
return text; | |||
} | |||
} | |||
} |
@@ -0,0 +1,131 @@ | |||
package com.xueyi.nlt.netty.client; | |||
import com.xueyi.nlt.netty.client.listener.LlmWebSocketListener; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
import okhttp3.*; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
import org.springframework.beans.factory.annotation.Value; | |||
import org.springframework.stereotype.Component; | |||
import javax.annotation.PostConstruct; | |||
import javax.annotation.PreDestroy; | |||
import javax.crypto.Mac; | |||
import javax.crypto.spec.SecretKeySpec; | |||
import java.net.URL; | |||
import java.nio.charset.Charset; | |||
import java.text.SimpleDateFormat; | |||
import java.util.*; | |||
import java.util.concurrent.ExecutorService; | |||
import java.util.concurrent.LinkedBlockingQueue; | |||
import java.util.concurrent.ThreadPoolExecutor; | |||
import java.util.concurrent.TimeUnit; | |||
@Component | |||
public class WebSocketClientManager { | |||
private static final Logger logger = LoggerFactory.getLogger(WebSocketClientManager.class); | |||
private static final int MAX_CONCURRENT_CONNECTIONS = 5; | |||
private static final int CORE_POOL_SIZE = 2; | |||
private static final int MAX_POOL_SIZE = 4; | |||
private static final long KEEP_ALIVE_TIME = 60L; | |||
private ExecutorService connectionPool; | |||
private OkHttpClient client; | |||
@Value("${secret.spark.apiSecret}") | |||
private String apiSecret; | |||
@Value("${secret.spark.apiKey}") | |||
private String apiKey; | |||
@Value("${secret.spark.hostUrl}") | |||
public String hostUrl; | |||
@PostConstruct | |||
public void init() { | |||
ConnectionPool okHttpConnectionPool = new ConnectionPool(MAX_CONCURRENT_CONNECTIONS, KEEP_ALIVE_TIME, TimeUnit.SECONDS); | |||
client = new OkHttpClient.Builder().connectionPool(okHttpConnectionPool).build(); | |||
connectionPool = new ThreadPoolExecutor(CORE_POOL_SIZE, MAX_POOL_SIZE, KEEP_ALIVE_TIME, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); | |||
} | |||
public static void main(String[] args) { | |||
LlmContext context = new LlmContext("今天北京天气怎么样"); | |||
LlmParam param = new LlmParam(); | |||
LlmWebSocketListener listener = new LlmWebSocketListener("12345", param, context,false); | |||
synchronized (listener) { | |||
try { | |||
//构建鉴权httpurl | |||
String authUrl = getAuthorizationUrl("https://spark-api.xf-yun.com/v3.1/chat", "54f6e81f40a31d66d976496de895a7a4", "ZDYyMjNmMTlkYTE0YWRmOWUwZTYxNjYz"); | |||
OkHttpClient okHttpClient = new OkHttpClient.Builder().build(); | |||
String url = authUrl.replace("https://","wss://").replace("http://","ws://"); | |||
Request request = new Request.Builder().url(url).build(); | |||
WebSocket webSocket = okHttpClient.newWebSocket(request,listener); | |||
listener.wait(); | |||
System.out.println("查询完成"); | |||
} catch (InterruptedException ie) { | |||
ie.printStackTrace(); | |||
Thread.currentThread().interrupt(); | |||
} catch (Exception e) { | |||
e.printStackTrace(); | |||
} | |||
} | |||
} | |||
/** | |||
* 调用讯飞开放平台接口发送消息 | |||
* @param listener | |||
*/ | |||
public void startWebSocketClient(LlmWebSocketListener listener) { | |||
connectionPool.execute(() -> { | |||
try { | |||
String authUrl = getAuthorizationUrl(hostUrl, apiKey, apiSecret); | |||
String url = authUrl.replace("https://", "wss://").replace("http://", "ws://"); | |||
Request request = new Request.Builder().url(url).build(); | |||
WebSocket webSocket = client.newWebSocket(request, listener); | |||
} catch (Exception e) { | |||
logger.error("startWebSocketClient error", e.getStackTrace()); | |||
} | |||
}); | |||
} | |||
//鉴权url | |||
public static String getAuthorizationUrl(String hostUrl , String apikey ,String apisecret) throws Exception { | |||
//获取host | |||
URL url = new URL(hostUrl); | |||
//获取鉴权时间 date | |||
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US); | |||
System.out.println("format:\n" + format ); | |||
format.setTimeZone(TimeZone.getTimeZone("GMT")); | |||
String date = format.format(new Date()); | |||
//获取signature_origin字段 | |||
StringBuilder builder = new StringBuilder("host: ").append(url.getHost()).append("\n"). | |||
append("date: ").append(date).append("\n"). | |||
append("GET ").append(url.getPath()).append(" HTTP/1.1"); | |||
System.out.println("signature_origin:\n" + builder); | |||
//获得signatue | |||
Charset charset = Charset.forName("UTF-8"); | |||
Mac mac = Mac.getInstance("hmacsha256"); | |||
SecretKeySpec sp = new SecretKeySpec(apisecret.getBytes(charset),"hmacsha256"); | |||
mac.init(sp); | |||
byte[] basebefore = mac.doFinal(builder.toString().getBytes(charset)); | |||
String signature = Base64.getEncoder().encodeToString(basebefore); | |||
//获得 authorization_origin | |||
String authorization_origin = String.format("api_key=\"%s\",algorithm=\"%s\",headers=\"%s\",signature=\"%s\"",apikey,"hmac-sha256","host date request-line",signature); | |||
//获得authorization | |||
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes(charset)); | |||
//获取httpurl | |||
HttpUrl httpUrl = HttpUrl.parse("https://" + url.getHost() + url.getPath()).newBuilder().// | |||
addQueryParameter("authorization", authorization).// | |||
addQueryParameter("date", date).// | |||
addQueryParameter("host", url.getHost()).// | |||
build(); | |||
return httpUrl.toString(); | |||
} | |||
} |
@@ -0,0 +1,325 @@ | |||
package com.xueyi.nlt.netty.client.listener; | |||
import com.alibaba.fastjson2.JSON; | |||
import com.alibaba.fastjson2.JSONArray; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.alibaba.nacos.shaded.com.google.gson.Gson; | |||
import com.alibaba.nacos.shaded.com.google.gson.JsonArray; | |||
import com.alibaba.nacos.shaded.com.google.gson.JsonElement; | |||
import com.alibaba.nacos.shaded.com.google.gson.JsonObject; | |||
import com.baomidou.mybatisplus.core.toolkit.StringUtils; | |||
import com.xueyi.nlt.netty.server.config.ServerConfig; | |||
import com.xueyi.nlt.nlt.domain.LlmContent; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
import io.netty.channel.Channel; | |||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; | |||
import okhttp3.Response; | |||
import okhttp3.WebSocket; | |||
import okhttp3.WebSocketListener; | |||
import okio.ByteString; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
import javax.annotation.Nullable; | |||
import java.util.ArrayList; | |||
import java.util.List; | |||
public class LlmWebSocketListener extends WebSocketListener { | |||
private static final Logger logger = LoggerFactory.getLogger(LlmWebSocketListener.class); | |||
private static final Integer SPARK_RESPONSE_STATUS_FIRST_RESULT = 0; | |||
private static final Integer SPARK_RESPONSE_STATUS_INTERMEDIATE_RESULT = 1; | |||
private static final Integer SPARK_RESPONSE_STATUS_LAST_RESULT = 2; | |||
private static final Double RANDOM_THRESHOLD = 0.5; | |||
protected String appId; | |||
protected LlmParam llmParam; | |||
protected LlmContext llmContext; | |||
public String systemRole = ""; | |||
// 是否为流式 | |||
protected boolean stream = false; | |||
// 是否关闭websocket | |||
private Boolean wsCloseFlag = false; | |||
public String question = "";//可以修改question 内容,来向模型提问 | |||
public List<String> questions = new ArrayList<>();//可以修改question 内容,来向模型提问 | |||
public String answer = ""; | |||
public String answerBuf = ""; | |||
public static final Gson json = new Gson(); | |||
public LlmWebSocketListener(String appId, LlmParam llmParam, LlmContext context, boolean stream) { | |||
this.appId = appId; | |||
this.llmParam = llmParam; | |||
this.llmContext = context; | |||
this.stream = stream; | |||
} | |||
@Override | |||
public void onOpen(WebSocket webSocket, Response response) { | |||
super.onOpen(webSocket, response); | |||
new Thread(()->{ | |||
try { | |||
JsonObject frame = new JsonObject(); | |||
JsonObject header = new JsonObject(); | |||
JsonObject chat = new JsonObject(); | |||
JsonObject parameter = new JsonObject(); | |||
JsonObject payload = new JsonObject(); | |||
JsonObject message = new JsonObject(); | |||
JsonObject text = new JsonObject(); | |||
JsonArray ja = new JsonArray(); | |||
//填充header | |||
header.addProperty("app_id",appId); | |||
header.addProperty("uid","123456789"); | |||
//填充parameter | |||
// chat.addProperty("domain","general"); //1.0版本 | |||
chat.addProperty("domain",llmParam.getModel()); // 3.0版本 | |||
chat.addProperty("temperature",llmParam.getTemperature()); | |||
chat.addProperty("max_tokens",llmParam.getMaxTokens()); | |||
parameter.add("chat",chat); | |||
// 插入大模型上下文 | |||
for (LlmContent llmContent : llmContext.getContentList()) { | |||
JsonElement jsonTree = json.toJsonTree(llmContent, LlmContent.class); | |||
ja.add(jsonTree); | |||
} | |||
// message.addProperty("text",ja.getAsString()); | |||
message.add("text",ja); | |||
payload.add("message",message); | |||
frame.add("header",header); | |||
frame.add("parameter",parameter); | |||
frame.add("payload",payload); | |||
System.out.println("frame:\n" + frame.toString()); | |||
webSocket.send(frame.toString()); | |||
// 等待服务端返回完毕后关闭 | |||
while (true) { | |||
// System.err.println(wsCloseFlag + "---"); | |||
Thread.sleep(200); | |||
if (wsCloseFlag) { | |||
break; | |||
} | |||
} | |||
webSocket.close(1000, ""); | |||
} catch (InterruptedException e) { | |||
logger.error("websocket发送消息异常:{}",e.getMessage()); | |||
} | |||
} | |||
).start(); | |||
} | |||
@Override | |||
public void onMessage(WebSocket webSocket, String text) { | |||
super.onMessage(webSocket, text); | |||
System.out.println("text:\n" + text); | |||
if (ServerConfig.currentTraceMap.containsKey(llmContext.getDevId()) && !ServerConfig.currentTraceMap.get(llmContext.getDevId()).equals(llmContext.getTraceId())) { | |||
return; | |||
} | |||
ResponseData responseData = json.fromJson(text, ResponseData.class); | |||
synchronized (this) { | |||
// 如果返回的code不等于0,打印错误日志,设置websocket关闭标志位,设置answer为抱歉,您的问题我无法解答。如果为流式调用则向channel发送错误信息 | |||
if (0 != responseData.header.code) { | |||
logger.error("返回结果错误:{}" , responseData); | |||
this.wsCloseFlag = true; | |||
this.answer = "抱歉,您的问题我无法解答。"; | |||
if (this.stream) { | |||
Channel ch = ServerConfig.sessionMap.get(llmContext.getDevId()); | |||
if (ch != null) { | |||
JSONObject jo = formatToChannel("抱歉,您的问题我无法解答。",SPARK_RESPONSE_STATUS_LAST_RESULT,responseData.header.code); | |||
logger.info("发送到client:{},id:{},内容:{}",llmContext.getDevId(),ch.id().asLongText(),jo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString())); | |||
} | |||
} else { | |||
this.notifyAll(); | |||
} | |||
return; | |||
} | |||
try { | |||
System.out.println("###########"); | |||
System.out.println("getStatus: " + responseData.header.status); | |||
List<Text> textList = responseData.payload.choices.text; | |||
for (Text temp : textList) { | |||
answer += temp.content; | |||
answerBuf += temp.content; | |||
} | |||
if (stream) { | |||
Channel ch = ServerConfig.sessionMap.get(llmContext.getDevId()); | |||
logger.info("当前ch:{}",ch.id().asLongText()); | |||
if (ch != null) { | |||
// 向数字人推送流式消息,将answerBuf中的内容做拆分,如果包含"。",则将"。"之前的内容以列表形式发送给前端 | |||
List<String> ttsList = new ArrayList<>(); | |||
//去除转义符 | |||
answerBuf = answerBuf.replaceAll("[\\r\\n]", ""); | |||
//去除引号 | |||
answerBuf = answerBuf.replaceAll("\"", ""); | |||
// 处理answer,如果包含"。",则将"。"之前的内容发送给前端 | |||
while(answerBuf.contains("。") || answerBuf.contains("?") || answerBuf.contains("!") || | |||
answerBuf.contains("?") || answerBuf.contains("!")) { | |||
String[] temp = answerBuf.split("。|?|!|\\?|\\!"); | |||
ttsList.add(temp[0] + answerBuf.charAt(temp[0].length())); | |||
answerBuf = answerBuf.substring(temp[0].length() + 1); | |||
} | |||
for (int i = 0;i < ttsList.size();i++) { | |||
JSONObject resJo; | |||
if (i < ttsList.size() - 1) { | |||
resJo = formatToChannel(ttsList.get(i),1,responseData.header.code); | |||
} else { | |||
resJo = formatToChannel(ttsList.get(i),responseData.header.status,responseData.header.code); | |||
} | |||
logger.info("发送到client:{},id:{},内容:{}",llmContext.getDevId(),ch.id().asLongText(),resJo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(resJo.toJSONString())); | |||
} | |||
} | |||
} | |||
// 判断当前状态是否为发送完成,如果为发送完成,则打印answer,设置websocket关闭标志位,如果为非流式调用,释放锁。 | |||
if (SPARK_RESPONSE_STATUS_LAST_RESULT == responseData.header.status) { | |||
System.out.println("返回结果为:\n" + answer); | |||
//关闭释放资源 | |||
this.wsCloseFlag = true; | |||
// 如果是非流式调用,释放锁 | |||
if (!stream) { | |||
this.notifyAll(); | |||
} | |||
} | |||
} catch (Exception e) { | |||
logger.error("返回结果错误:{}" , e.getMessage()); | |||
this.answer = "大模型出现异常,请稍后重试。"; | |||
if (stream) { | |||
Channel ch = ServerConfig.sessionMap.get(llmContext.getDevId()); | |||
if (ch != null) { | |||
JSONObject jo = formatToChannel("大模型出现异常,请稍后重试。",SPARK_RESPONSE_STATUS_LAST_RESULT,-1); | |||
logger.info("发生异常client:{},内容:{}",llmContext.getDevId(),jo.toJSONString()); | |||
ch.writeAndFlush(new TextWebSocketFrame(jo.toJSONString())); | |||
} | |||
} else { | |||
this.notifyAll(); | |||
} | |||
} | |||
} | |||
} | |||
@Override | |||
public void onMessage(WebSocket webSocket, ByteString bytes) { | |||
super.onMessage(webSocket, bytes); | |||
} | |||
@Override | |||
public void onClosing(WebSocket webSocket, int code, String reason) { | |||
super.onClosing(webSocket, code, reason); | |||
} | |||
@Override | |||
public void onClosed(WebSocket webSocket, int code, String reason) { | |||
super.onClosed(webSocket, code, reason); | |||
} | |||
@Override | |||
public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) { | |||
super.onFailure(webSocket, t, response); | |||
} | |||
// 定义一个局部方法,从SessionMap中获取获取当前用户的channel,向channel发送JSON格式的Response | |||
public JSONObject formatToChannel(String msg, Integer status,Integer code) { | |||
JSONObject jo = new JSONObject(); | |||
jo.put("action","chat"); | |||
jo.put("motion","idle"); | |||
jo.put("traceId",llmContext.getTraceId()); | |||
jo.put("status",status); | |||
jo.put("tts",msg); | |||
jo.put("code",code); | |||
return jo; | |||
} | |||
class ResponseData{ | |||
private Header header; | |||
private Payload payload; | |||
public Header getHeader() { | |||
return header; | |||
} | |||
public Payload getPayload() { | |||
return payload; | |||
} | |||
} | |||
class Header{ | |||
private int code ; | |||
private String message; | |||
private String sid; | |||
private int status; | |||
public int getCode() { | |||
return code; | |||
} | |||
public String getMessage() { | |||
return message; | |||
} | |||
public String getSid() { | |||
return sid; | |||
} | |||
public int getStatus() { | |||
return status; | |||
} | |||
} | |||
class Payload{ | |||
private Choices choices; | |||
private Useage usage; | |||
public Choices getChoices() { | |||
return choices; | |||
} | |||
public Useage getUsage() { | |||
return usage; | |||
} | |||
} | |||
class Choices{ | |||
int status; | |||
int seq; | |||
List<Text> text; | |||
public int getStatus() { | |||
return status; | |||
} | |||
public int getSeq() { | |||
return seq; | |||
} | |||
public List<Text> getText() { | |||
return text; | |||
} | |||
} | |||
class Text { | |||
String role; | |||
String content; | |||
} | |||
class Useage{ | |||
/** 保留字段,可忽略 */ | |||
Integer question_tokens; | |||
/** 包含历史问题的总tokens大小 */ | |||
Integer prompt_tokens; | |||
/** 回答的tokens大小 */ | |||
Integer completion_tokens; | |||
/** prompt_tokens和completion_tokens的和,也是本次交互计费的tokens大小 */ | |||
Integer total_tokens; | |||
} | |||
} |
@@ -2,16 +2,9 @@ package com.xueyi.nlt.netty.controller; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.baomidou.mybatisplus.core.toolkit.StringUtils; | |||
import com.xueyi.common.cache.utils.SourceUtil; | |||
import com.xueyi.common.core.constant.basic.SecurityConstants; | |||
import com.xueyi.common.core.web.result.AjaxResult; | |||
import com.xueyi.common.core.web.result.R; | |||
import com.xueyi.nlt.api.netty.domain.vo.DmWebSocketMessageVo; | |||
import com.xueyi.nlt.api.nlt.domain.vo.DmIntentVo; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.nlt.domain.vo.IntentTemplateVo; | |||
import com.xueyi.system.api.digitalmans.domain.dto.DmManDeviceDto; | |||
import com.xueyi.system.api.model.Source; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
@@ -29,7 +22,7 @@ public class DmWebsocketController { | |||
private static final Logger log = LoggerFactory.getLogger(DmWebsocketController.class); | |||
@Autowired | |||
WebSocketClient webSocketClient; | |||
WebSocketClientManager webSocketClientManager; | |||
@Autowired | |||
StringRedisTemplate stringRedisTemplate; | |||
@@ -54,26 +47,6 @@ public class DmWebsocketController { | |||
SimpleDateFormat dateFormat3 = new SimpleDateFormat("MM-dd"); | |||
Double timestamp = Double.valueOf((String)jo.get("timestamp")); | |||
String meetingRoom = jo.getString("meetingRoom"); | |||
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) { | |||
String prefix = "假设你是一名公司前台,你看到" + message.getFormat().get("name")+ "已知今天是他的生日。请你从个人角度输出给他的生日贺词。要求待人平和,具有人情味,用词正式,内容与工作无关。输出只包含你要对他说的话,在20字以内。"; | |||
webSocketClient.sendMsg(prefix); | |||
redisTemplate.opsForValue().set("gpt:websocket" + ":" + "1", message); | |||
return R.ok(); | |||
} | |||
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hiredate")) { | |||
String prefix = "假设你是一名公司前台,你看到"+ message.getFormat().get("name")+ ",已知今天是他入职1周年,请你从个人角度说出对他入职周年的祝贺。要求具有人情味,有特色,字数在25字左右,不要提到生日,要带人名。输出只包含你要对他说的话。"; | |||
webSocketClient.sendMsg(prefix); | |||
redisTemplate.opsForValue().set("gpt:websocket" + ":" + "1", message); | |||
return R.ok(); | |||
} | |||
Date date = new Date(timestamp.longValue()); | |||
if (message.getSkillCode().equals("1")) { | |||
String prefix = "假设你是一名公司前台,你看到在你们公司工作的\\"+ jo.getString("orderName")+ "\\,请你从个人角度提醒他参加\\" + | |||
dateFormat3.format(timestamp) + "\\在\\" + meetingRoom + "\\的会,要求语气友好。输出只包含你要对他说的话,在20字左右。"; | |||
webSocketClient.sendMsg(prefix); | |||
// 设置缓存 | |||
redisTemplate.opsForValue().set("gpt:websocket" + ":" + "1", message); | |||
} | |||
return R.ok(); | |||
} | |||
} |
@@ -1,7 +1,5 @@ | |||
package com.xueyi.nlt.netty.server.config; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import io.netty.channel.Channel; | |||
import java.util.concurrent.ConcurrentHashMap; | |||
@@ -11,7 +11,6 @@ import com.xueyi.common.core.constant.digitalman.SkillConstants.SkillType; | |||
import com.xueyi.common.core.context.SecurityContextHolder; | |||
import com.xueyi.common.core.utils.DateUtil; | |||
import com.xueyi.common.core.utils.core.IdUtil; | |||
import com.xueyi.common.core.utils.core.SpringUtils; | |||
import com.xueyi.common.core.web.result.AjaxResult; | |||
import com.xueyi.common.core.web.result.R; | |||
import com.xueyi.common.core.web.validate.V_A; | |||
@@ -33,7 +32,7 @@ import com.xueyi.nlt.api.nlt.domain.vo.response.DmKnowledgeResponse; | |||
import com.xueyi.nlt.api.nlt.feign.RemoteIntentService; | |||
import com.xueyi.nlt.api.nlt.feign.RemoteLandingLlmService; | |||
import com.xueyi.nlt.api.nlt.feign.RemoteQAService; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
@@ -54,8 +53,6 @@ import com.xueyi.system.api.digitalmans.feign.RemoteDigitalmanService; | |||
import com.xueyi.system.api.digitalmans.feign.RemoteManDeviceService; | |||
import com.xueyi.system.api.digitalmans.feign.RemoteQuestionanswersService; | |||
import com.xueyi.system.api.digitalmans.feign.RemoteSkillService; | |||
import com.xueyi.system.api.interfaces.airport.domain.vo.PlaneMessageVo; | |||
import com.xueyi.system.api.interfaces.airport.feign.RemotePlaneController; | |||
import com.xueyi.system.api.model.Source; | |||
import com.xueyi.system.api.organize.domain.dto.SysEnterpriseDto; | |||
import com.xueyi.system.api.organize.feign.RemoteEnterpriseService; | |||
@@ -64,7 +61,6 @@ import org.slf4j.LoggerFactory; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.data.redis.core.RedisTemplate; | |||
import org.springframework.data.redis.core.StringRedisTemplate; | |||
import org.springframework.util.DigestUtils; | |||
import org.springframework.validation.annotation.Validated; | |||
import org.springframework.web.bind.annotation.DeleteMapping; | |||
import org.springframework.web.bind.annotation.GetMapping; | |||
@@ -77,9 +73,7 @@ import org.springframework.web.bind.annotation.ResponseBody; | |||
import org.springframework.web.bind.annotation.RestController; | |||
import java.io.Serializable; | |||
import java.text.DateFormat; | |||
import java.text.ParseException; | |||
import java.text.SimpleDateFormat; | |||
import java.time.format.DateTimeFormatter; | |||
import java.util.Arrays; | |||
import java.util.Date; | |||
import java.util.List; | |||
@@ -99,7 +93,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt | |||
IDmIntentService dmIntentService; | |||
@Autowired | |||
WebSocketClient webSocketClient; | |||
WebSocketClientManager webSocketClientManager; | |||
/** 定义节点名称 */ | |||
@Override | |||
@@ -598,6 +592,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt | |||
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("birthday")) { | |||
String prefix = "假设你是一名公司前台,你看到" + message.getFormat().get("name")+ "已知今天是他的生日。请你从个人角度输出给他的生日贺词。要求待人平和,具有人情味,用词正式,内容与工作无关。输出只包含你要对他说的话,在20字以内。"; | |||
LlmContext context = new LlmContext(prefix); | |||
context.setDevId(message.getDevId()); | |||
LlmResponse response = sysLlmService.chat(context, new LlmParam()); | |||
JSONObject birthdayJo = new JSONObject(); | |||
birthdayJo.put("content", response.getContent()); | |||
@@ -608,6 +603,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt | |||
if (StringUtils.isNotEmpty(message.getTemplate()) && message.getTemplate().equals("hireDate")) { | |||
String prefix = "假设你是一名公司前台,你看到"+ message.getFormat().get("name")+ ",已知今天是他入职" + message.getFormat().get("years")+"周年,请你从个人角度说出对他入职周年的祝贺。要求具有人情味,有特色,字数在25字左右,不要提到生日,要带人名。输出只包含你要对他说的话。"; | |||
LlmContext context = new LlmContext(prefix); | |||
context.setDevId(message.getDevId()); | |||
LlmResponse response = sysLlmService.chat(context, new LlmParam()); | |||
JSONObject hireDateJo = new JSONObject(); | |||
hireDateJo.put("content", response.getContent()); | |||
@@ -620,6 +616,7 @@ public class DmIntentController extends BaseController<DmIntentQuery, DmIntentDt | |||
String prefix = "假设你是一名公司前台,你看到在你们公司工作的\\"+ jo.getString("orderName")+ "\\,请你从个人角度提醒他参加\\" + | |||
dateFormat4.format(timestamp) + "\\在\\" + meetingRoom + "\\的会,要求语气友好。输出只包含你要对他说的话,在20字左右。"; | |||
LlmContext context = new LlmContext(prefix); | |||
context.setDevId(message.getDevId()); | |||
LlmResponse response = sysLlmService.chat(context, new LlmParam()); | |||
JSONObject meetingJo = new JSONObject(); | |||
meetingJo.put("content", response.getContent()); | |||
@@ -1,6 +1,7 @@ | |||
package com.xueyi.nlt.nlt.domain; | |||
import com.alibaba.fastjson2.JSONArray; | |||
import com.xueyi.common.core.utils.core.IdUtil; | |||
import lombok.Data; | |||
import lombok.NoArgsConstructor; | |||
@@ -10,14 +11,22 @@ import java.util.List; | |||
import java.util.stream.Collectors; | |||
@Data | |||
@NoArgsConstructor | |||
public class LlmContext implements Serializable { | |||
private String devId; | |||
private List<LlmContent> contentList; | |||
private Long traceId; | |||
public LlmContext() { | |||
// 创建随机的traceId | |||
traceId = IdUtil.getSnowflakeNextId(); | |||
contentList = new ArrayList<>(); | |||
} | |||
public LlmContext(String message) { | |||
traceId = IdUtil.getSnowflakeNextId(); | |||
contentList = new ArrayList<>(); | |||
LlmContent ctx = new LlmContent("user",message); | |||
contentList.add(ctx); | |||
@@ -5,13 +5,26 @@ import lombok.Data; | |||
@Data | |||
public class LlmParam { | |||
// 模型选择 | |||
protected String llm; | |||
protected Integer maxTokens; | |||
protected Double temperature; | |||
protected Integer topK; | |||
// 星火大模型 | |||
protected String model; | |||
protected String randomThreshold; | |||
/** | |||
* 构造函数(无参数) | |||
* 默认为星火大模型 | |||
*/ | |||
public LlmParam() { | |||
this.llm = "spark"; | |||
this.model = "generalv3"; | |||
this.maxTokens = 1024; | |||
this.temperature = 0.75; | |||
this.topK = 1; | |||
this.temperature = 0.5; | |||
this.topK = 4; | |||
} | |||
} |
@@ -1,13 +1,16 @@ | |||
package com.xueyi.nlt.nlt.service.impl; | |||
import com.xueyi.common.core.utils.core.SpringUtils; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import com.xueyi.nlt.netty.client.listener.LlmWebSocketListener; | |||
import com.xueyi.nlt.netty.server.config.ServerConfig; | |||
import com.xueyi.nlt.nlt.domain.LlmContent; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
import com.xueyi.nlt.nlt.domain.LlmResponse; | |||
import com.xueyi.nlt.nlt.service.ISysLlmService; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.beans.factory.annotation.Value; | |||
import org.springframework.context.annotation.Primary; | |||
import org.springframework.data.redis.core.StringRedisTemplate; | |||
import org.springframework.stereotype.Service; | |||
@@ -20,26 +23,35 @@ import java.util.stream.Collectors; | |||
public class SparkServiceImpl implements ISysLlmService { | |||
@Autowired | |||
WebSocketClient webSocketClient; | |||
WebSocketClientManager webSocketClientManager; | |||
@Autowired | |||
private StringRedisTemplate redisTemplate; | |||
@Value("${secret.spark.appId}") | |||
private String appId; | |||
@Override | |||
public LlmResponse chat(LlmContext context, LlmParam param) { | |||
List<String> contentArr = context.getContentList().stream().map(LlmContent::getContent).collect(Collectors.toList()); | |||
WebSocketClient socketClient = SpringUtils.getBean(WebSocketClient.class); | |||
webSocketClient = socketClient.sendMsg(contentArr); | |||
synchronized (webSocketClient) { | |||
LlmWebSocketListener listener = new LlmWebSocketListener(appId, param, context,false); | |||
ServerConfig.currentTraceMap.put(context.getDevId(),context.getTraceId()); | |||
webSocketClientManager.startWebSocketClient(listener); | |||
synchronized (listener) { | |||
try { | |||
webSocketClient.wait(); | |||
listener.wait(); | |||
} catch (InterruptedException e) { | |||
e.printStackTrace(); | |||
Thread.currentThread().interrupt(); | |||
} | |||
String result = redisTemplate.opsForValue().get("group:websocket:content"); | |||
result = webSocketClient.answer; | |||
String result = listener.answer; | |||
// 添加上下文 | |||
// // 添加上下文 | |||
// INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", questions.get(questions.size() - 1)); | |||
// INSTANCE.redisTemplate.opsForList().rightPush("group:nlp:null:-1", answer); | |||
// // 添加缓存 | |||
// INSTANCE.stringRedisTemplate.opsForValue().set("group:websocket:content", answer); | |||
LlmResponse response = new LlmResponse(); | |||
response.setContent(result); | |||
return response; | |||
@@ -50,8 +62,10 @@ public class SparkServiceImpl implements ISysLlmService { | |||
@Override | |||
public LlmResponse stream(LlmContext context, LlmParam param) { | |||
List<String> contentArr = context.getContentList().stream().map(LlmContent::getContent).collect(Collectors.toList()); | |||
webSocketClient.sendMsg(contentArr, true,context.getDevId()); | |||
LlmResponse response = new LlmResponse(); | |||
LlmWebSocketListener listener = new LlmWebSocketListener(appId, param, context,true); | |||
ServerConfig.currentTraceMap.put(context.getDevId(),context.getTraceId()); | |||
webSocketClientManager.startWebSocketClient(listener); | |||
LlmResponse response = new LlmResponse(); | |||
return response; | |||
} | |||
} |
@@ -1,11 +1,9 @@ | |||
package com.xueyi.nlt.nlt.template; | |||
import com.alibaba.druid.util.StringUtils; | |||
import com.alibaba.fastjson2.JSONArray; | |||
import com.alibaba.fastjson2.JSONException; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.xueyi.common.core.context.SecurityContextHolder; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
@@ -49,87 +47,83 @@ public class FreeChatTemplate implements BaseTemplate{ | |||
Long operatorId = TerminalSecurityContextHolder.getOperatorId(); | |||
String redisKey = "group:nlp:" + SecurityContextHolder.getLocalMap().get("enterprise_id") + ":" + operatorId; | |||
// 根据content内容调用模版并返回结果 | |||
synchronized (WebSocketClient.LOCK) { | |||
// 通过redis获取数字人上下文信息 | |||
Long size = redisTemplate.opsForList().size(redisKey); | |||
if (size > 8) { | |||
redisTemplate.opsForList().leftPop(redisKey,2); | |||
} | |||
size = redisTemplate.opsForList().size(redisKey); | |||
List<String> context = new ArrayList<>(); | |||
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人。"); | |||
context.addAll(redisTemplate.opsForList().range(redisKey,size-6,size)); | |||
// 通过redis获取数字人上下文信息 | |||
Long size = redisTemplate.opsForList().size(redisKey); | |||
if (size > 8) { | |||
redisTemplate.opsForList().leftPop(redisKey,2); | |||
} | |||
size = redisTemplate.opsForList().size(redisKey); | |||
List<String> context = new ArrayList<>(); | |||
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人。"); | |||
context.addAll(redisTemplate.opsForList().range(redisKey,size-6,size)); | |||
context.add(content); | |||
context.add(content); | |||
//webSocketClient.sendMsg(context); | |||
//webSocketClient.sendMsg(context); | |||
LlmContext llmContext = LlmContext.parse(context,true); | |||
LlmParam param = new LlmParam(); | |||
LlmResponse response = sysLlmService.chat(llmContext,param); | |||
String result = response.getContent(); | |||
LlmContext llmContext = LlmContext.parse(context,true); | |||
LlmParam param = new LlmParam(); | |||
LlmResponse response = sysLlmService.chat(llmContext,param); | |||
String result = response.getContent(); | |||
// 处理数据 | |||
if (result.contains("我是科大讯飞")) { | |||
result = result.replaceAll("科大讯飞", "缔智元"); | |||
} | |||
result = result.replaceAll("认知模型", "数字员工"); | |||
result = result.replaceAll("认知智能模型", "数字员工"); | |||
if (result.equals("-1")) { | |||
result = "这个问题超出了我无法回答,您可以提出更多关于电影相关问题。"; | |||
} | |||
// 处理数据 | |||
if (result.contains("我是科大讯飞")) { | |||
result = result.replaceAll("科大讯飞", "缔智元"); | |||
} | |||
result = result.replaceAll("认知模型", "数字员工"); | |||
result = result.replaceAll("认知智能模型", "数字员工"); | |||
if (result.equals("-1")) { | |||
result = "这个问题超出了我无法回答,您可以提出更多关于电影相关问题。"; | |||
} | |||
if(!StringUtils.isEmpty(result)){ | |||
redisTemplate.opsForList().rightPush(redisKey,content); | |||
redisTemplate.opsForList().rightPush(redisKey,result); | |||
} | |||
JSONObject resultJson = new JSONObject(); | |||
resultJson.put("msg",result); | |||
return resultJson; | |||
if(!StringUtils.isEmpty(result)){ | |||
redisTemplate.opsForList().rightPush(redisKey,content); | |||
redisTemplate.opsForList().rightPush(redisKey,result); | |||
} | |||
JSONObject resultJson = new JSONObject(); | |||
resultJson.put("msg",result); | |||
return resultJson; | |||
} | |||
public JSONObject handle(String dev, String content, boolean stream) { | |||
Long operatorId = TerminalSecurityContextHolder.getOperatorId(); | |||
String redisKey = "group:nlp:" + SecurityContextHolder.getLocalMap().get("enterprise_id") + ":" + operatorId; | |||
// 根据content内容调用模版并返回结果 | |||
synchronized (WebSocketClient.LOCK) { | |||
// 通过redis获取数字人上下文信息 | |||
Long size = redisTemplate.opsForList().size(redisKey); | |||
if (size > 8) { | |||
redisTemplate.opsForList().leftPop(redisKey,2); | |||
} | |||
size = redisTemplate.opsForList().size(redisKey); | |||
List<String> context = new ArrayList<>(); | |||
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人,你们公司在北京,你的职能是负责做会议预定和访客预约。"); | |||
// 中航信领导来访临时对策 | |||
// 判断如果content包含correctWordsMap中的key,则替换为value | |||
for (String key : correctWordsMap.keySet()) { | |||
if (content.contains(key)) { | |||
content = content.replaceAll(key, correctWordsMap.get(key)); | |||
} | |||
// 通过redis获取数字人上下文信息 | |||
Long size = redisTemplate.opsForList().size(redisKey); | |||
if (size > 8) { | |||
redisTemplate.opsForList().leftPop(redisKey,2); | |||
} | |||
size = redisTemplate.opsForList().size(redisKey); | |||
List<String> context = new ArrayList<>(); | |||
context.add("你是缔智元公司的前台,你叫小智,你是一位数字人,你们公司在北京,你的职能是负责做会议预定和访客预约。"); | |||
// 中航信领导来访临时对策 | |||
// 判断如果content包含correctWordsMap中的key,则替换为value | |||
for (String key : correctWordsMap.keySet()) { | |||
if (content.contains(key)) { | |||
content = content.replaceAll(key, correctWordsMap.get(key)); | |||
} | |||
//context.addAll(redisTemplate.opsForList().range(redisKey,size-2,size)); | |||
content = "请用简短的话回答下面的问题:" + content; | |||
context.add(content); | |||
//使用stream去除context列表中所有字符串中的引号 | |||
context = context.stream().map(s -> s.replaceAll("\"", "")).collect(java.util.stream.Collectors.toList()); | |||
//webSocketClient.sendMsg(context); | |||
LlmContext llmContext = LlmContext.parse(context,true); | |||
llmContext.setDevId(dev); | |||
LlmParam param = new LlmParam(); | |||
LlmResponse response = sysLlmService.stream(llmContext,param); | |||
JSONObject resultJson = new JSONObject(); | |||
resultJson.put("tts","让我想一想。"); | |||
resultJson.put("motion","idle"); | |||
resultJson.put("status","0"); | |||
resultJson.put("action","chat"); | |||
return resultJson; | |||
} | |||
//context.addAll(redisTemplate.opsForList().range(redisKey,size-2,size)); | |||
content = "请用简短的话回答下面的问题:" + content; | |||
context.add(content); | |||
//使用stream去除context列表中所有字符串中的引号 | |||
context = context.stream().map(s -> s.replaceAll("\"", "")).collect(java.util.stream.Collectors.toList()); | |||
//webSocketClient.sendMsg(context); | |||
LlmContext llmContext = LlmContext.parse(context,true); | |||
llmContext.setDevId(dev); | |||
LlmParam param = new LlmParam(); | |||
LlmResponse response = sysLlmService.stream(llmContext,param); | |||
JSONObject resultJson = new JSONObject(); | |||
resultJson.put("tts","让我想一想。"); | |||
resultJson.put("motion","idle"); | |||
resultJson.put("status","0"); | |||
resultJson.put("action","chat"); | |||
return resultJson; | |||
} | |||
@Override | |||
public JSONObject handle(String dev, String content, Long tenantId) { | |||
@@ -3,7 +3,11 @@ package com.xueyi.nlt.nlt.template; | |||
import com.alibaba.fastjson2.JSONArray; | |||
import com.alibaba.fastjson2.JSONException; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
import com.xueyi.nlt.nlt.domain.LlmResponse; | |||
import com.xueyi.nlt.nlt.service.ISysLlmService; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
@@ -16,41 +20,34 @@ public class GenerativeKnowledgeTemplate implements BaseTemplate{ | |||
private static final Logger log = LoggerFactory.getLogger(GenerativeKnowledgeTemplate.class); | |||
@Autowired | |||
WebSocketClient webSocketClient; | |||
WebSocketClientManager webSocketClientManager; | |||
@Autowired | |||
private ISysLlmService sysLlmService; | |||
@Autowired | |||
private RedisTemplate<String,String> redisTemplate; | |||
@Override | |||
public JSONObject handle(String devId, String content) { | |||
JSONObject jsonObject = new JSONObject(); | |||
// 根据content内容调用模版并返回结果 | |||
synchronized (WebSocketClient.LOCK) { | |||
String prefix = "你的任务是[针对给定的文段提出" + (content.length() /100 + 1 ) + "个问题并回答]。文段为:[\""; | |||
String suffix = "\"]。输出为一个JSON数组[{}],每个元素是一个JSON:{“question”:,”answer”:}。不要给出任何解释说明。"; | |||
log.info(prefix + content + suffix); | |||
webSocketClient.sendMsg(prefix + content + suffix); | |||
try { | |||
WebSocketClient.LOCK.wait(); | |||
String result = (String)redisTemplate.opsForValue().get("group:websocket:content"); | |||
try { | |||
JSONArray jsonArray = JSONArray.parseArray(result); | |||
JSONObject jsonObject = new JSONObject(); | |||
jsonObject.put("questions",jsonArray); | |||
return jsonObject; | |||
} catch (JSONException je) { | |||
// 返回结果错误,计日志,存log,返回空结果 | |||
log.error(je.getMessage(),je); | |||
return new JSONObject(); | |||
} | |||
} catch (InterruptedException e) { | |||
log.warn(e.getMessage()); | |||
Thread.currentThread().interrupt(); | |||
} | |||
String prefix = "你的任务是[针对给定的文段提出" + (content.length() /100 + 1 ) + "个问题并回答]。文段为:[\""; | |||
String suffix = "\"]。输出为一个JSON数组[{}],每个元素是一个JSON:{“question”:,”answer”:}。不要给出任何解释说明。"; | |||
log.info(prefix + content + suffix); | |||
LlmContext llmContext = new LlmContext(prefix + content + suffix); | |||
llmContext.setDevId(devId); | |||
LlmParam llmParam = new LlmParam(); | |||
LlmResponse response = sysLlmService.chat(llmContext,llmParam); | |||
try { | |||
JSONArray jsonArray = JSONArray.parseArray(response.getContent()); | |||
jsonObject.put("questions",jsonArray); | |||
return jsonObject; | |||
} catch (JSONException je) { | |||
// 返回结果错误,计日志,存log,返回空结果 | |||
log.error(je.getMessage(),je); | |||
} | |||
return new JSONObject(); | |||
return jsonObject; | |||
} | |||
@Override | |||
@@ -2,9 +2,7 @@ package com.xueyi.nlt.nlt.template; | |||
import co.elastic.clients.elasticsearch.ElasticsearchClient; | |||
import co.elastic.clients.elasticsearch._types.query_dsl.MatchQuery; | |||
import co.elastic.clients.elasticsearch._types.query_dsl.Query; | |||
import co.elastic.clients.elasticsearch.core.SearchResponse; | |||
import co.elastic.clients.elasticsearch.core.search.Hit; | |||
import com.alibaba.cloud.nacos.NacosConfigManager; | |||
import com.alibaba.fastjson2.JSONArray; | |||
import com.alibaba.fastjson2.JSONException; | |||
@@ -15,14 +13,8 @@ import com.xueyi.common.core.constant.basic.SecurityConstants; | |||
import com.xueyi.common.core.constant.digitalman.SkillConstants; | |||
import com.xueyi.common.core.web.result.R; | |||
import com.xueyi.nlt.api.nlt.domain.vo.CoversationSessionVo; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.nlt.config.MeetingParam; | |||
import com.xueyi.nlt.nlt.controller.DmIntentController; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
import com.xueyi.nlt.nlt.domain.LlmResponse; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import com.xueyi.nlt.nlt.domain.vo.MeetingParamVo; | |||
import com.xueyi.nlt.nlt.service.ISysLlmService; | |||
import com.xueyi.system.api.digitalmans.domain.po.DmDigitalmanExtPo; | |||
import com.xueyi.system.api.digitalmans.feign.RemoteDigitalmanService; | |||
import com.xueyi.system.api.meeting.domain.dto.DmMeetingRoomsDto; | |||
@@ -30,23 +22,19 @@ import com.xueyi.system.api.meeting.feign.RemoteMeetingService; | |||
import com.xueyi.system.api.model.Source; | |||
import com.xueyi.system.api.organize.domain.dto.SysEnterpriseDto; | |||
import com.xueyi.system.api.organize.feign.RemoteEnterpriseService; | |||
import com.xueyi.system.api.pass.domain.po.DmRecognizedRecordsPo; | |||
import org.slf4j.Logger; | |||
import org.slf4j.LoggerFactory; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.data.redis.core.RedisTemplate; | |||
import org.springframework.data.redis.core.StringRedisTemplate; | |||
import org.springframework.stereotype.Component; | |||
import org.springframework.stereotype.Service; | |||
import javax.annotation.PostConstruct; | |||
import java.io.IOException; | |||
import java.time.LocalDate; | |||
import java.time.LocalDateTime; | |||
import java.time.format.DateTimeFormatter; | |||
import java.util.ArrayList; | |||
import java.util.Arrays; | |||
import java.util.Comparator; | |||
import java.util.List; | |||
import java.util.concurrent.TimeUnit; | |||
import java.util.regex.Matcher; | |||
@@ -59,7 +47,7 @@ public class MeetingOrderTemplate implements BaseTemplate { | |||
private static final Logger log = LoggerFactory.getLogger(MeetingOrderTemplate.class); | |||
@Autowired | |||
WebSocketClient webSocketClient; | |||
WebSocketClientManager webSocketClientManager; | |||
// @Autowired | |||
// RemoteMeetingService remoteMeetingService; | |||
@@ -3,7 +3,7 @@ package com.xueyi.nlt.nlt.template; | |||
import com.alibaba.druid.util.StringUtils; | |||
import com.alibaba.fastjson2.JSONObject; | |||
import com.xueyi.common.core.context.SecurityContextHolder; | |||
import com.xueyi.nlt.netty.client.WebSocketClient; | |||
import com.xueyi.nlt.netty.client.WebSocketClientManager; | |||
import com.xueyi.nlt.nlt.context.TerminalSecurityContextHolder; | |||
import com.xueyi.nlt.nlt.domain.LlmContext; | |||
import com.xueyi.nlt.nlt.domain.LlmParam; | |||
@@ -24,7 +24,7 @@ public class MovieChatTemplate implements BaseTemplate{ | |||
private static final Logger log = LoggerFactory.getLogger(MovieChatTemplate.class); | |||
@Autowired | |||
WebSocketClient webSocketClient; | |||
WebSocketClientManager webSocketClientManager; | |||
@Autowired | |||
ISysLlmService sysLlmService; | |||