Não pode escolher mais do que 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

86 linhas
2.9 KiB

  1. import sys
  2. import configparser
  3. import pandas as pd
  4. import torch
  5. from torch import nn
  6. import torch.optim as optim
  7. def main(**param):
  8. L ={"dashscope":1536, "liandong":1024, "zhipu":1024}
  9. answerCache = []
  10. config = configparser.ConfigParser()
  11. config.read("settings.ini", encoding="utf-8")
  12. modelName = param['model']
  13. embeddings = param['embedding']
  14. mode = param['mode']
  15. print(modelName, embeddings, mode)
  16. emName = embeddings + 'embedding'
  17. embedding_config = dict(config.items(emName))
  18. ems = __import__('embeddings.%s' % emName,
  19. fromlist=['embeddings'])
  20. string = 'ems.' + emName.capitalize()
  21. embedding = eval(string)(**embedding_config)
  22. mds = __import__('models.%s' % modelName,
  23. fromlist=['models'])
  24. string = 'mds.' + modelName.capitalize()
  25. model = eval(string)(L[embeddings])
  26. criterion = nn.MSELoss()
  27. optimizer = optim.Adam(model.parameters(), lr=0.001)
  28. if mode == 'train':
  29. train = pd.read_csv('data/train.csv')
  30. for i in range(20):
  31. nloss = 0
  32. for k in range(len(train)):
  33. va = embedding.getem(train.iloc[k]['question'])
  34. vb = embedding.getem(train.iloc[k]['answer'])
  35. if train.iloc[k]['answer'] not in answerCache:
  36. answerCache.append(train.iloc[k]['answer'])
  37. trainTensor = model.prosess(va, vb)
  38. output = model(trainTensor)
  39. # 计算损失
  40. predict = torch.tensor(train.iloc[k]['label']).float()
  41. predict = predict.reshape([1, 1])
  42. loss = criterion(output,
  43. predict)
  44. # 反向传播并更新权重
  45. optimizer.zero_grad()
  46. loss.backward()
  47. optimizer.step()
  48. nloss += loss
  49. if k % 50 == 0:
  50. print(i, k, 'done')
  51. print('one loop done', nloss/len(train))
  52. torch.save(model, 'models/%s.pth' % modelName)
  53. if mode == 'test':
  54. n = 0
  55. model = torch.load('models/%s.pth' % modelName)
  56. model.eval()
  57. test = pd.read_csv('data/test.csv')
  58. for i in range(len(test)):
  59. va = embedding.getem(test.iloc[i]['question'])
  60. vb = embedding.getem(test.iloc[i]['answer'])
  61. testTensor = model.prosess(va, vb)
  62. output = model(testTensor)
  63. if output > 0.5 and test.iloc[i]['label'] == 1:
  64. n += 1
  65. if output < 0.5 and test.iloc[i]['label'] == 0:
  66. n += 1
  67. print(n/len(test))
  68. if __name__ == '__main__':
  69. if not len(sys.argv) == 4:
  70. arg1 = 'cnn'
  71. arg2 = 'dashscope'
  72. arg3 = 'train'
  73. else:
  74. # 从命令行参数中获取参数值
  75. arg1 = sys.argv[1]
  76. arg2 = sys.argv[2]
  77. arg3 = sys.argv[3]
  78. main(model=arg1, embedding=arg2, mode=arg3)