25'ten fazla konu seçemezsiniz Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.

120 satır
3.5 KiB

  1. import torch
  2. import dashscope
  3. from http import HTTPStatus
  4. from dashscope import TextEmbedding
  5. from torch import nn
  6. import torch.optim as optim
  7. import pandas as pd
  8. import numpy as np
  import os

  # Credentials: prefer the DASHSCOPE_API_KEY environment variable so the key
  # does not have to live in source control. The literal fallback keeps the
  # script behaving exactly as before when the variable is not set.
  # NOTE(review): this key is committed in plain text -- rotate it and remove
  # the fallback once every environment exports DASHSCOPE_API_KEY.
  dashscope.api_key = os.environ.get(
      'DASHSCOPE_API_KEY', 'sk-44ccc9ab5e754eddb545cade12b632cf')

  # Memoized embeddings, keyed by the exact input text (filled by getem).
  cache = {}
  # Distinct training answers in first-seen order; the evaluation loop scores
  # each test question against every entry of this list.
  answerCache = []
  def getem(question):
      """Return the embedding vector for ``question``, memoized in ``cache``.

      On a cache miss the text is sent to the DashScope ``text_embedding_v1``
      model with ``text_type='query'``. A successful result is stored in the
      module-level ``cache`` dict, so each distinct text is requested once.

      Args:
          question: Text to embed; it is also used as the cache key.

      Returns:
          The value at ``resp['output']['embeddings'][0]['embedding']``.

      Raises:
          RuntimeError: If the service answers with a status other than
              ``HTTPStatus.OK``. Failing here gives a clear error instead of
              handing callers a missing embedding that breaks later on.
      """
      if question in cache:
          return cache[question]

      resp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,
                                input=question,
                                text_type='query')
      if resp.status_code != HTTPStatus.OK:
          raise RuntimeError(
              'embedding request failed with status %s' % (resp.status_code,))

      embedding = resp['output']['embeddings'][0]['embedding']
      # Only successful responses are cached, so a failed call can be retried.
      cache[question] = embedding
      return embedding
  class ConvNet(nn.Module):
      """Score a pair of embedding vectors with three parallel 1x1 convolutions.

      The input is a tensor with 2 channels (one per embedding) of length
      ``embedding_dim``. Each ``Conv1d(2, 1, kernel_size=1)`` branch mixes the
      two channels into one, the three branch outputs are concatenated along
      the length axis, and a linear layer followed by a sigmoid maps the
      result to two values per sample.

      Args:
          embedding_dim: Length of each embedding vector. Defaults to 1536,
              the width this model was originally hard-coded for.
      """

      def __init__(self, embedding_dim=1536):
          super().__init__()
          self.embedding_dim = embedding_dim
          # Layers are created in the same order and under the same attribute
          # names as before, so seeded initialisation and state_dict keys are
          # unchanged.
          self.conv1 = nn.Conv1d(2, 1, kernel_size=1, stride=1, padding=0)
          self.relu1 = nn.ReLU()
          self.conv2 = nn.Conv1d(2, 1, kernel_size=1, stride=1, padding=0)
          self.relu2 = nn.ReLU()
          self.conv3 = nn.Conv1d(2, 1, kernel_size=1, stride=1, padding=0)
          self.relu3 = nn.ReLU()
          self.fc = nn.Linear(embedding_dim * 3, 2)
          self.sigmoid = nn.Sigmoid()

      def forward(self, x):
          """Return sigmoid scores of shape ``(batch, 2)`` for input ``x``.

          Args:
              x: Tensor of shape ``(batch, 2, embedding_dim)``.
          """
          # Each branch uses its own activation module; ReLU is stateless, so
          # this is numerically identical to sharing a single instance.
          branches = [
              self.relu1(self.conv1(x)),
              self.relu2(self.conv2(x)),
              self.relu3(self.conv3(x)),
          ]
          merged = torch.cat(branches, dim=2)
          flat = merged.view(-1, self.embedding_dim * 3)
          return self.sigmoid(self.fc(flat))
  # Create the model instance.
  model = ConvNet()
  # Define the loss function and the optimizer.
  criterion = nn.MSELoss()
  optimizer = optim.Adam(model.parameters(), lr=0.001)

  # Load columns 1..3 of data.csv, shuffle the rows, then use the first 300
  # rows for training and whatever remains for evaluation.
  dataall = pd.read_csv('data.csv')
  dataall = dataall.iloc[:, 1:4]
  dataall = dataall.sample(frac=1)
  train = dataall.iloc[0:300].reset_index(drop=True)
  test = dataall.iloc[300:].reset_index(drop=True)

  # Kept as a module-level name for compatibility; it is not read in this block.
  nlossLast = 0

  # Targets for the two-way output: [1, 0] when label == 1, [0, 1] otherwise.
  # They are loop-invariant, so build them once.
  positive_target = torch.tensor([1, 0]).float().reshape([1, 2])
  negative_target = torch.tensor([0, 1]).float().reshape([1, 2])

  for i in range(5):
      nloss = 0.0
      for k in range(len(train)):
          row = train.iloc[k]
          va = getem(row['question'])
          vb = getem(row['answer'])
          # Remember every distinct training answer as a retrieval candidate.
          if row['answer'] not in answerCache:
              answerCache.append(row['answer'])
          trainTensor = torch.Tensor([va, vb]).reshape([1, 2, len(va)])
          output = model(trainTensor)
          # Compute the loss.
          if row['label'] == 1:
              loss = criterion(output, positive_target)
          else:
              loss = criterion(output, negative_target)
          # Backpropagate and update the weights.
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          # Accumulate a plain float: summing the loss tensors themselves would
          # carry autograd history across the whole epoch.
          nloss += loss.item()
          if k % 50 == 0:
              print(i, k, 'done')
      # Guard against an empty training split instead of dividing by zero.
      print('one loop done', nloss / len(train) if len(train) else float('nan'))

  # Evaluation: a test row counts as correct when one of the two best-scoring
  # candidates (or "no answer" when the best score is <= 0.5) matches the
  # expected answer.
  p = 0
  with torch.no_grad():  # inference only; no gradients are needed here
      for i in range(len(test)):
          row = test.iloc[i]
          va = getem(row['question'])
          Scores = np.zeros(len(answerCache))
          for j, candidate in enumerate(answerCache):
              vb = getem(candidate)
              testTensor = torch.Tensor([va, vb]).reshape([1, 2, len(va)])
              output = model(testTensor)
              Scores[j] = output[0][0].item()
          # The expected answer does not change between the two attempts.
          vc = row['answer'] if row['label'] == 1 else ''
          for k in range(2):
              tt = Scores.argmax()
              if Scores[tt] > 0.5:
                  vb = answerCache[tt]
                  # Exclude this candidate from the second attempt.
                  Scores[tt] = -1
              else:
                  vb = ''
              if vb == vc:
                  p += 1
                  break

  if len(test):
      print(p/len(test))
  else:
      # With 300 or fewer rows in data.csv there is nothing to evaluate.
      print('no test rows available for evaluation')