评分系统分流部分的代码
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

222 lines
8.8 KiB

# coding=utf-8
import configparser
import importlib
import json
from datetime import datetime

import numpy as np

from tools.logger import Logger
  7. LTHRESHOLD = 0.4
  8. MTHRESHOLD = 0.6
  9. HTHRESHOLD = 0.8
  10. DB = 'xy-cloud1'
  11. INDEX = 'dm_q_and_a,dm_questions'
  12. cache = {}
  13. class Knowledge():
  14. def __init__(self, model):
  15. self.model = model + 'llm'
  16. self.config = configparser.ConfigParser()
  17. self.config.read("settings.ini", encoding="utf-8")
  18. self.sourcename = self.config.get('config', 'source') + 'source'
  19. source_config = dict(self.config.items(self.sourcename))
  20. source_config['index'] = INDEX
  21. source_config['db'] = DB
  22. self.sources = __import__('sources.%s' % self.sourcename,
  23. fromlist=['sources'])
  24. string = 'self.sources.' + self.sourcename.capitalize()
  25. self.source = eval(string)(**source_config)
  26. self.logger = Logger(self.config.get("config", "logger_path"))
  27. self.emname = self.config.get('config', 'embedding') + 'embedding'
  28. embedding_config = dict(self.config.items(self.emname))
  29. self.ems = __import__('embeddings.%s' % self.emname,
  30. fromlist=['embeddings'])
  31. string = 'self.ems.' + self.emname.capitalize()
  32. self.embedding = eval(string)(**embedding_config)
  33. self.llmname = self.model
  34. llm_config = dict(self.config.items(self.llmname))
  35. self.llms = __import__('llms.%s' % self.llmname,
  36. fromlist=['llms'])
  37. string = 'self.llms.' + self.llmname.capitalize()
  38. self.llm = eval(string)(**llm_config)
  39. def combine(self, result, question, accurate, tenantId):
  40. Result = {'code': 200, 'target': 1}
  41. data = {}
  42. if not result == '':
  43. data['result'] = result
  44. data['question'] = question
  45. data['accurate'] = accurate
  46. data['llm'] = self.config.get('config', 'model')
  47. data['tenantId'] = tenantId
  48. Result['data'] = data
  49. return Result
  50. def log_header(self, method):
  51. timestamp = datetime.strftime(datetime.now(), '%Y-%m-%dT%H:%M:%S:%f')
  52. timestamp = '"timestamp":"' + timestamp + '",'
  53. tenant = '"tentant_id":' + str(self.tenant_id) + ','
  54. em = '"embedding":"' + self.emname + '",'
  55. model = '"model":"' + self.llmname + '",'
  56. source = '"source":"' + self.sourcename + '",'
  57. questionstr = '"question":"' + self.question + '","answers":['
  58. logs = ',"method":' + method + ','
  59. logs = timestamp + tenant + em + model + source + logs + questionstr
  60. return logs
  61. def emsearch(self):
  62. data_all = self.source.getdata(tenant_id=self.tenant_id)
  63. logs = self.log_header('"emlist"')
  64. for i in data_all:
  65. item = '"' + i['name'] + '",' + i['id']
  66. logs = logs + '[' + item + '],'
  67. logs = logs[0:-1] if logs[-1] == ',' else logs
  68. self.logger.info(logs + ']')
  69. em_list = []
  70. score_em_max = -2
  71. if len(data_all) > 0:
  72. all_em_score = np.zeros(len(data_all))
  73. v1 = self.embedding.getem(self.question)
  74. for i in range(len(data_all)):
  75. if data_all[i]['name'] in cache.keys():
  76. v2 = cache[data_all[i]['name']]
  77. else:
  78. v2 = self.embedding.getem(data_all[i]['name'])
  79. cache[data_all[i]['name']] = v2
  80. numerator = np.dot(v1, v2)
  81. denominator = (np.linalg.norm(v1) * np.linalg.norm(v2))
  82. all_em_score[i] = numerator / denominator
  83. logs = self.log_header('"emsearch"')
  84. for i in range(6):
  85. t = np.argmax(all_em_score)
  86. if all_em_score[t] < LTHRESHOLD:
  87. break
  88. if i == 0:
  89. score_em_max = all_em_score[t]
  90. if all_em_score[t] <= -1:
  91. break
  92. em_list.append(data_all[t])
  93. item = '"' + data_all[t]['name'] + '",'
  94. item = item + data_all[t]['id']
  95. item = item + ',%.3f' % all_em_score[t]
  96. logs = logs + '[' + item + '],'
  97. all_em_score[t] = -2
  98. logs = logs[0:-1] if logs[-1] == ',' else logs
  99. self.logger.info(logs + ']')
  100. return [em_list, score_em_max]
  101. def recommend(self):
  102. logs = self.log_header('"searchbegin"')
  103. self.logger.info(logs + ']')
  104. [result_list, score_em_max] = self.emsearch()
  105. if score_em_max > HTHRESHOLD:
  106. result_list[0]['highlight'] = 1
  107. logs = self.log_header('"recommend"')
  108. for i in result_list:
  109. logs = logs + '["' + i['name'] + '",' + i['id'] + '],'
  110. logs = logs[0:-1] if logs[-1] == ',' else logs
  111. self.logger.info(logs + ']')
  112. return self.combine(result_list[0:4],
  113. self.question,
  114. 1,
  115. self.tenant_id)
  116. if score_em_max > MTHRESHOLD:
  117. result_list[0]['highlight'] = 1
  118. logs = self.log_header('"recommendlist"')
  119. for i in result_list:
  120. logs = logs + '["' + i['name'] + '",' + i['id'] + '],'
  121. logs = logs[0:-1] if logs[-1] == ',' else logs
  122. self.logger.info(logs + ']')
  123. if len(result_list) == 0:
  124. logs = self.log_header('"recommend"')
  125. self.logger.info(logs + ']')
  126. return self.combine("",
  127. self.question,
  128. 0,
  129. self.tenant_id)
  130. if len(result_list) <= 4:
  131. logs = self.log_header('"recommend"')
  132. for i in result_list:
  133. logs = logs + '["' + i['name'] + '",' + i['id'] + '],'
  134. logs = logs[0:-1] if logs[-1] == ',' else logs
  135. self.logger.info(logs + ']')
  136. return self.combine(result_list,
  137. self.question,
  138. 0,
  139. self.tenant_id)
  140. if self.config.get('config', 'usellm') == "0":
  141. logs = self.log_header('"recommend"')
  142. for i in result_list:
  143. logs = logs + '["' + i['name'] + '",' + i['id'] + '],'
  144. logs = logs[0:-1] if logs[-1] == ',' else logs
  145. self.logger.info(logs + ']')
  146. return self.combine(result_list[0:4],
  147. self.question,
  148. 0,
  149. self.tenant_id)
  150. L = ''
  151. for i in result_list:
  152. L = L + '"' + i['name'] + '",'
  153. L = L[0:-1]
  154. Q1 = "请从列表{"
  155. Q2 = "}选出与'"
  156. Q3 = "'意图最接近的四句话,将结果以{question1:,question2:,question3:,question4:}输出。"
  157. Q4 = '输出为JSON格式。答案只能来自于列表。不要返回代码。不要输出JSON之外的东西'
  158. Q = Q1 + L + Q2 + self.question + Q3 + Q4
  159. logs = self.log_header('"llmin"')
  160. self.logger.info(logs + '"' + Q + '"]')
  161. answer, tokens = self.llm.link(Q)
  162. logs = self.log_header('"llmout"')
  163. tokens = str(tokens).replace("'", '"')
  164. logs = logs[0:-1] + repr(answer) + ',"tokens":' + str(tokens)
  165. self.logger.info(logs)
  166. begin = answer.find('{')
  167. end = answer.rfind('}')
  168. answer = answer[begin:end+1]
  169. answer = answer.replace('\\n', '')
  170. answer = answer.replace('\\"', '"')
  171. logs = self.log_header('"llmsearch"')
  172. answer = answer.replace("'", '"')
  173. try:
  174. data = json.loads(answer)
  175. result = []
  176. for key in data:
  177. for i in result_list:
  178. if data[key] == i['name']:
  179. if i not in result:
  180. result.append(i)
  181. except Exception:
  182. result = result_list[0:4]
  183. err_logs = self.log_header('"llmsearch"')
  184. err_logs = err_logs + '"Can not trans to JSON.]"'
  185. self.logger.error(err_logs)
  186. for i in result:
  187. logs = logs + '["' + i['name'] + '",' + i['id'] + '],'
  188. while len(result) < 4:
  189. for i in result_list:
  190. if i not in result:
  191. result.append(i)
  192. break
  193. if len(result) > 4:
  194. result = result[0:4]
  195. logs = logs[0:-1] if logs[-1] == ',' else logs
  196. self.logger.info(logs + ']')
  197. logs = self.log_header('"recommend"')
  198. for i in result:
  199. logs = logs + '["' + i['name'] + '",' + i['id'] + '],'
  200. logs = logs[0:-1] if logs[-1] == ',' else logs
  201. self.logger.info(logs + ']')
  202. return self.combine(result,
  203. self.question,
  204. 0,
  205. self.tenant_id)