Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

86 рядки
2.9 KiB

  1. import sys
  2. import configparser
  3. import pandas as pd
  4. import torch
  5. from torch import nn
  6. import torch.optim as optim
  7. def main(**param):
  8. L ={"dashscope":1536, "liandong":1024, "zhipu":1024}
  9. answerCache = []
  10. config = configparser.ConfigParser()
  11. config.read("settings.ini", encoding="utf-8")
  12. modelName = param['model']
  13. embeddings = param['embedding']
  14. mode = param['mode']
  15. print(modelName, embeddings, mode)
  16. emName = embeddings + 'embedding'
  17. embedding_config = dict(config.items(emName))
  18. ems = __import__('embeddings.%s' % emName,
  19. fromlist=['embeddings'])
  20. string = 'ems.' + emName.capitalize()
  21. embedding = eval(string)(**embedding_config)
  22. mds = __import__('models.%s' % modelName,
  23. fromlist=['models'])
  24. string = 'mds.' + modelName.capitalize()
  25. model = eval(string)(L[embeddings])
  26. criterion = nn.MSELoss()
  27. optimizer = optim.Adam(model.parameters(), lr=0.001)
  28. if mode == 'train':
  29. train = pd.read_csv('data/train.csv')
  30. for i in range(20):
  31. nloss = 0
  32. for k in range(len(train)):
  33. va = embedding.getem(train.iloc[k]['question'])
  34. vb = embedding.getem(train.iloc[k]['answer'])
  35. if train.iloc[k]['answer'] not in answerCache:
  36. answerCache.append(train.iloc[k]['answer'])
  37. trainTensor = model.prosess(va, vb)
  38. output = model(trainTensor)
  39. # 计算损失
  40. predict = torch.tensor(train.iloc[k]['label']).float()
  41. predict = predict.reshape([1, 1])
  42. loss = criterion(output,
  43. predict)
  44. # 反向传播并更新权重
  45. optimizer.zero_grad()
  46. loss.backward()
  47. optimizer.step()
  48. nloss += loss
  49. if k % 50 == 0:
  50. print(i, k, 'done')
  51. print('one loop done', nloss/len(train))
  52. torch.save(model, 'models/%s.pth' % modelName)
  53. if mode == 'test':
  54. n = 0
  55. model = torch.load('models/%s.pth' % modelName)
  56. model.eval()
  57. test = pd.read_csv('data/test.csv')
  58. for i in range(len(test)):
  59. va = embedding.getem(test.iloc[i]['question'])
  60. vb = embedding.getem(test.iloc[i]['answer'])
  61. testTensor = model.prosess(va, vb)
  62. output = model(testTensor)
  63. if output > 0.5 and test.iloc[i]['label'] == 1:
  64. n += 1
  65. if output < 0.5 and test.iloc[i]['label'] == 0:
  66. n += 1
  67. print(n/len(test))
  68. if __name__ == '__main__':
  69. if not len(sys.argv) == 4:
  70. arg1 = 'cnn'
  71. arg2 = 'dashscope'
  72. arg3 = 'train'
  73. else:
  74. # 从命令行参数中获取参数值
  75. arg1 = sys.argv[1]
  76. arg2 = sys.argv[2]
  77. arg3 = sys.argv[3]
  78. main(model=arg1, embedding=arg2, mode=arg3)