大模型接口远程访问
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

327 lines
11 KiB

  1. """
  2. A simple wrapper for the official ChatGPT API
  3. https://github.com/acheong08/ChatGPT.git v3.py
  4. """
  5. import json
  6. import os
  7. import sys
  8. import platform
  9. from typing import NoReturn
  10. import requests
  11. import tiktoken #解析成token令牌
  12. class Chatbot_V3:
  13. """
  14. Official ChatGPT API
  15. """
  16. def __init__(
  17. self,
  18. api_key: str,
  19. engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
  20. proxy: str = None,
  21. max_tokens: int = 3000,
  22. temperature: float = 0.5,
  23. top_p: float = 1.0,
  24. presence_penalty: float = 0.0,
  25. frequency_penalty: float = 0.0,
  26. reply_count: int = 1,
  27. system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
  28. ) -> None:
  29. """
  30. Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
  31. """
  32. self.engine = engine
  33. self.session = requests.Session()
  34. self.api_key = api_key
  35. self.system_prompt = system_prompt
  36. self.max_tokens = max_tokens
  37. self.temperature = temperature
  38. self.top_p = top_p
  39. self.presence_penalty = presence_penalty
  40. self.frequency_penalty = frequency_penalty
  41. self.reply_count = reply_count
  42. if proxy:
  43. self.session.proxies = {
  44. "http": proxy,
  45. "https": proxy,
  46. }
  47. self.conversation: dict = {
  48. "default": [
  49. {
  50. "role": "system",
  51. "content": system_prompt,
  52. },
  53. ],
  54. }
  55. if max_tokens > 4000:
  56. raise Exception("Max tokens cannot be greater than 4000")
  57. if self.get_token_count("default") > self.max_tokens:
  58. raise Exception("System prompt is too long")
  59. self.skills = {}
  60. # 判断是否存在skills.csv文件,如果存在则读取key,value并加载进skills字典
  61. print("加载skills.csv文件")
  62. if os.path.exists("/Users/ruoxiyin/Documents/缔智元/代码/chatGPT_Web-main/skills.csv"):
  63. with open("/Users/ruoxiyin/Documents/缔智元/代码/chatGPT_Web-main/skills.csv", "r", encoding="utf-8") as f:
  64. for line in f.readlines():
  65. key, value = line.strip().split("|")
  66. self.skills[key] = value
  67. def add_to_conversation(
  68. self,
  69. message: str,
  70. role: str,
  71. convo_id: str = "default",
  72. ) -> None:
  73. """
  74. Add a message to the conversation
  75. """
  76. self.conversation[convo_id].append({"role": role, "content": message})
  77. def get_conversation(
  78. self,
  79. convo_id: str = "default",
  80. ) -> list:
  81. """
  82. Get the conversation
  83. """
  84. return self.conversation[convo_id]
  85. def __truncate_conversation(self, convo_id: str = "default") -> None:
  86. """
  87. Truncate the conversation
  88. """
  89. while True:
  90. if (
  91. self.get_token_count(convo_id) > self.max_tokens
  92. and len(self.conversation[convo_id]) > 1
  93. ):
  94. # Don't remove the first message
  95. self.conversation[convo_id].pop(1)
  96. else:
  97. break
  98. # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  99. def get_token_count(self, convo_id: str = "default") -> int:
  100. """
  101. Get token count
  102. """
  103. if self.engine not in ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"]:
  104. raise NotImplementedError("Unsupported engine {self.engine}")
  105. encoding = tiktoken.encoding_for_model(self.engine)
  106. num_tokens = 0
  107. for message in self.conversation[convo_id]:
  108. # every message follows <im_start>{role/name}\n{content}<im_end>\n
  109. num_tokens += 4
  110. for key, value in message.items():
  111. num_tokens += len(encoding.encode(value))
  112. if key == "name": # if there's a name, the role is omitted
  113. num_tokens += 1 # role is always required and always 1 token
  114. num_tokens += 2 # every reply is primed with <im_start>assistant
  115. return num_tokens
  116. def get_max_tokens(self, convo_id: str) -> int:
  117. """
  118. Get max tokens
  119. """
  120. return self.max_tokens - self.get_token_count(convo_id)
  121. def ask_stream(
  122. self,
  123. prompt: str,
  124. role: str = "user",
  125. convo_id: str = "default",
  126. **kwargs,
  127. ) -> str:
  128. """
  129. Ask a question
  130. """
  131. # Make conversation if it doesn't exist
  132. if convo_id not in self.conversation:
  133. self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
  134. self.add_to_conversation(prompt, "user", convo_id=convo_id)
  135. self.__truncate_conversation(convo_id=convo_id)
  136. print(convo_id)
  137. # Get response
  138. response = self.session.post(
  139. os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
  140. headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
  141. json={
  142. "model": self.engine,
  143. "messages": self.conversation[convo_id],
  144. "stream": True,
  145. # kwargs
  146. "temperature": kwargs.get("temperature", self.temperature),
  147. "top_p": kwargs.get("top_p", self.top_p),
  148. "presence_penalty": kwargs.get(
  149. "presence_penalty",
  150. self.presence_penalty,
  151. ),
  152. "frequency_penalty": kwargs.get(
  153. "frequency_penalty",
  154. self.frequency_penalty,
  155. ),
  156. "n": kwargs.get("n", self.reply_count),
  157. "user": role,
  158. "max_tokens": self.get_max_tokens(convo_id=convo_id),
  159. },
  160. stream=True,
  161. )
  162. if response.status_code != 200:
  163. raise Exception(
  164. f"Error: {response.status_code} {response.reason} {response.text}",
  165. )
  166. response_role: str = None
  167. full_response: str = ""
  168. for line in response.iter_lines():
  169. if not line:
  170. continue
  171. # Remove "data: "
  172. line = line.decode("utf-8")[6:]
  173. if line == "[DONE]":
  174. break
  175. resp: dict = json.loads(line)
  176. choices = resp.get("choices")
  177. if not choices:
  178. continue
  179. delta = choices[0].get("delta")
  180. if not delta:
  181. continue
  182. if "role" in delta:
  183. response_role = delta["role"]
  184. if "content" in delta:
  185. content = delta["content"]
  186. full_response += content
  187. yield content
  188. self.add_to_conversation(full_response, response_role, convo_id=convo_id)
  189. def ask_stream_text(
  190. self,
  191. prompt: str,
  192. role: str = "user",
  193. convo_id: str = "default",
  194. **kwargs,
  195. ) -> str:
  196. """
  197. Ask a question, push as Streaming text
  198. """
  199. # Make conversation if it doesn't exist
  200. if convo_id not in self.conversation:
  201. self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
  202. # 去除上下文接入
  203. self.add_to_conversation(prompt, "user", convo_id=convo_id)
  204. self.__truncate_conversation(convo_id=convo_id)
  205. # Get response
  206. response = self.session.post(
  207. os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
  208. headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
  209. json={
  210. "model": self.engine,
  211. "messages": self.conversation[convo_id],
  212. "stream": True,
  213. # kwargs
  214. "temperature": kwargs.get("temperature", self.temperature),
  215. "top_p": kwargs.get("top_p", self.top_p),
  216. "presence_penalty": kwargs.get(
  217. "presence_penalty",
  218. self.presence_penalty,
  219. ),
  220. "frequency_penalty": kwargs.get(
  221. "frequency_penalty",
  222. self.frequency_penalty,
  223. ),
  224. "n": kwargs.get("n", self.reply_count),
  225. "user": role,
  226. "max_tokens": self.get_max_tokens(convo_id=convo_id),
  227. },
  228. stream=True,
  229. )
  230. if response.status_code != 200:
  231. raise Exception(
  232. f"Error: {response.status_code} {response.reason} {response.text}",
  233. )
  234. response_role: str = None
  235. full_response: str = ""
  236. for line in response.iter_lines():
  237. if not line:
  238. continue
  239. line = line.decode("utf-8")[6:]
  240. if line == "[DONE]":
  241. break;
  242. resp: dict = json.loads(line)
  243. choices = resp.get("choices")
  244. if not choices:
  245. continue
  246. delta = choices[0].get("delta")
  247. if not delta:
  248. continue
  249. if "role" in delta:
  250. response_role = delta["role"]
  251. if "content" in delta:
  252. content = delta['content']
  253. full_response += content
  254. yield 'data: {}\n\n'.format(json.dumps({'content':content})) #Got the answer content to push
  255. yield 'data: {}\n\n'.format(json.dumps({'content':'[DONE]'})) #The answer is over, pushing [DONE]
  256. # 去除上下文接入
  257. self.conversation[convo_id].clear()
  258. # self.add_to_conversation(full_response, response_role, convo_id=convo_id) #Add the answer to the conversation
  259. def ask(
  260. self,
  261. prompt: str,
  262. role: str = "user",
  263. convo_id: str = "default",
  264. **kwargs,
  265. ) -> str:
  266. """
  267. Non-streaming ask
  268. """
  269. response = self.ask_stream(
  270. prompt=prompt,
  271. role=role,
  272. convo_id=convo_id,
  273. **kwargs,
  274. )
  275. full_response: str = "".join(response)
  276. return full_response
  277. def rollback(self, n: int = 1, convo_id: str = "default") -> None:
  278. """
  279. Rollback the conversation
  280. """
  281. for _ in range(n):
  282. self.conversation[convo_id].pop()
  283. def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
  284. """
  285. Reset the conversation
  286. """
  287. self.conversation[convo_id] = [
  288. {"role": "system", "content": system_prompt or self.system_prompt},
  289. ]
  290. if __name__ == "__main__":
  291. #获取config.json中的数据
  292. path = os.path.split(os.path.realpath(__file__))[0]
  293. with open(path + '/config.json', 'r') as f:
  294. config = json.load(f)
  295. if platform.system().lower() == 'linux':
  296. proxys = None
  297. else:
  298. proxys = config['proxy']
  299. chat = Chatbot_V3(config['key'], "gpt-3.5-turbo", proxy=proxys)
  300. while True:
  301. promt = input("请输入:")
  302. if promt == 'exit':
  303. break;
  304. res = chat.ask(promt)
  305. print(res)