Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

120 Zeilen
3.5 KiB

  1. import torch
  2. import dashscope
  3. from http import HTTPStatus
  4. from dashscope import TextEmbedding
  5. from torch import nn
  6. import torch.optim as optim
  7. import pandas as pd
  8. import numpy as np
  9. dashscope.api_key = 'sk-44ccc9ab5e754eddb545cade12b632cf'
  10. cache = {}
  11. answerCache = []
  12. def getem(question):
  13. global cache
  14. if question in cache.keys():
  15. return cache[question]
  16. resp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,
  17. input=question,
  18. text_type='query')
  19. if resp.status_code == HTTPStatus.OK:
  20. cache[question] = resp['output']['embeddings'][0]['embedding']
  21. return resp['output']['embeddings'][0]['embedding']
  22. class ConvNet(nn.Module):
  23. def __init__(self):
  24. super(ConvNet, self).__init__()
  25. self.conv1 = nn.Conv1d(2, 1, kernel_size=1, stride=1, padding=0)
  26. self.relu1 = nn.ReLU()
  27. self.conv2 = nn.Conv1d(2, 1, kernel_size=1, stride=1, padding=0)
  28. self.relu2 = nn.ReLU()
  29. self.conv3 = nn.Conv1d(2, 1, kernel_size=1, stride=1, padding=0)
  30. self.relu3 = nn.ReLU()
  31. self.fc = nn.Linear(1536 * 3, 2)
  32. self.sigmoid = nn.Sigmoid()
  33. def forward(self, x):
  34. X1 = self.conv1(x)
  35. X1 = self.relu1(X1)
  36. X2 = self.conv2(x)
  37. X2 = self.relu1(X2)
  38. X3 = self.conv3(x)
  39. X3 = self.relu1(X3)
  40. X = torch.cat([X1, X2, X3], dim=2)
  41. X = X.view(-1, 1536 * 3)
  42. X = self.fc(X)
  43. X = self.sigmoid(X)
  44. return X
  45. # 创建模型实例
  46. model = ConvNet()
  47. # 定义损失函数和优化器
  48. criterion = nn.MSELoss()
  49. optimizer = optim.Adam(model.parameters(), lr=0.001)
  50. train = []
  51. test = []
  52. dataall = pd.read_csv('data.csv')
  53. dataall = dataall.iloc[:, 1:4]
  54. dataall = dataall.sample(frac=1)
  55. train = dataall.iloc[0:300]
  56. train = train.reset_index(drop=True)
  57. test = dataall.iloc[300:]
  58. test = test.reset_index(drop=True)
  59. nlossLast = 0
  60. for i in range(5):
  61. nloss = 0
  62. for k in range(len(train)):
  63. va = getem(train.iloc[k]['question'])
  64. vb = getem(train.iloc[k]['answer'])
  65. if train.iloc[k]['answer'] not in answerCache:
  66. answerCache.append(train.iloc[k]['answer'])
  67. trainTensor = torch.Tensor([va, vb]).reshape([1, 2, len(va)])
  68. output = model(trainTensor)
  69. # 计算损失
  70. if train.iloc[k]['label'] == 1:
  71. loss = criterion(output,
  72. torch.tensor([1, 0]).float().reshape([1, 2]))
  73. else:
  74. loss = criterion(output,
  75. torch.tensor([0, 1]).float().reshape([1, 2]))
  76. # 反向传播并更新权重
  77. optimizer.zero_grad()
  78. loss.backward()
  79. optimizer.step()
  80. nloss += loss
  81. if k % 50 == 0:
  82. print(i, k, 'done')
  83. print('one loop done', nloss/len(train))
  84. p = 0
  85. for i in range(len(test)):
  86. va = getem(test.iloc[i]['question'])
  87. Scores = np.zeros(len(answerCache))
  88. for j in range(len(answerCache)):
  89. vb = getem(answerCache[j])
  90. testTensor = torch.Tensor([va, vb]).reshape([1, 2, len(va)])
  91. output = model(testTensor)
  92. Scores[j] = output[0][0]
  93. for k in range(2):
  94. if test.iloc[i]['label'] == 1:
  95. vc = test.iloc[i]['answer']
  96. else:
  97. vc = ''
  98. tt = Scores.argmax()
  99. if Scores[tt] > 0.5:
  100. vb = answerCache[tt]
  101. Scores[tt] = -1
  102. else:
  103. vb = ''
  104. if vb == vc:
  105. p += 1
  106. break
  107. print(p/len(test))