Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

86 wierszy
2.9 KiB

  1. import sys
  2. import configparser
  3. import pandas as pd
  4. import torch
  5. from torch import nn
  6. import torch.optim as optim
  7. def main(**param):
  8. L ={"dashscope":1536, "liandong":1024, "zhipu":1024}
  9. answerCache = []
  10. config = configparser.ConfigParser()
  11. config.read("settings.ini", encoding="utf-8")
  12. modelName = param['model']
  13. embeddings = param['embedding']
  14. mode = param['mode']
  15. print(modelName, embeddings, mode)
  16. emName = embeddings + 'embedding'
  17. embedding_config = dict(config.items(emName))
  18. ems = __import__('embeddings.%s' % emName,
  19. fromlist=['embeddings'])
  20. string = 'ems.' + emName.capitalize()
  21. embedding = eval(string)(**embedding_config)
  22. mds = __import__('models.%s' % modelName,
  23. fromlist=['models'])
  24. string = 'mds.' + modelName.capitalize()
  25. model = eval(string)(L[embeddings])
  26. criterion = nn.MSELoss()
  27. optimizer = optim.Adam(model.parameters(), lr=0.001)
  28. if mode == 'train':
  29. train = pd.read_csv('data/train.csv')
  30. for i in range(20):
  31. nloss = 0
  32. for k in range(len(train)):
  33. va = embedding.getem(train.iloc[k]['question'])
  34. vb = embedding.getem(train.iloc[k]['answer'])
  35. if train.iloc[k]['answer'] not in answerCache:
  36. answerCache.append(train.iloc[k]['answer'])
  37. trainTensor = model.prosess(va, vb)
  38. output = model(trainTensor)
  39. # 计算损失
  40. predict = torch.tensor(train.iloc[k]['label']).float()
  41. predict = predict.reshape([1, 1])
  42. loss = criterion(output,
  43. predict)
  44. # 反向传播并更新权重
  45. optimizer.zero_grad()
  46. loss.backward()
  47. optimizer.step()
  48. nloss += loss
  49. if k % 50 == 0:
  50. print(i, k, 'done')
  51. print('one loop done', nloss/len(train))
  52. torch.save(model, 'models/%s.pth' % modelName)
  53. if mode == 'test':
  54. n = 0
  55. model = torch.load('models/%s.pth' % modelName)
  56. model.eval()
  57. test = pd.read_csv('data/test.csv')
  58. for i in range(len(test)):
  59. va = embedding.getem(test.iloc[i]['question'])
  60. vb = embedding.getem(test.iloc[i]['answer'])
  61. testTensor = model.prosess(va, vb)
  62. output = model(testTensor)
  63. if output > 0.5 and test.iloc[i]['label'] == 1:
  64. n += 1
  65. if output < 0.5 and test.iloc[i]['label'] == 0:
  66. n += 1
  67. print(n/len(test))
  68. if __name__ == '__main__':
  69. if not len(sys.argv) == 4:
  70. arg1 = 'cnn'
  71. arg2 = 'dashscope'
  72. arg3 = 'train'
  73. else:
  74. # 从命令行参数中获取参数值
  75. arg1 = sys.argv[1]
  76. arg2 = sys.argv[2]
  77. arg3 = sys.argv[3]
  78. main(model=arg1, embedding=arg2, mode=arg3)