25'ten fazla konu seçemezsiniz. Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.

86 satır
2.9 KiB

  1. import sys
  2. import configparser
  3. import pandas as pd
  4. import torch
  5. from torch import nn
  6. import torch.optim as optim
  7. def main(**param):
  8. L ={"dashscope":1536, "liandong":1024, "zhipu":1024}
  9. answerCache = []
  10. config = configparser.ConfigParser()
  11. config.read("settings.ini", encoding="utf-8")
  12. modelName = param['model']
  13. embeddings = param['embedding']
  14. mode = param['mode']
  15. print(modelName, embeddings, mode)
  16. emName = embeddings + 'embedding'
  17. embedding_config = dict(config.items(emName))
  18. ems = __import__('embeddings.%s' % emName,
  19. fromlist=['embeddings'])
  20. string = 'ems.' + emName.capitalize()
  21. embedding = eval(string)(**embedding_config)
  22. mds = __import__('models.%s' % modelName,
  23. fromlist=['models'])
  24. string = 'mds.' + modelName.capitalize()
  25. model = eval(string)(L[embeddings])
  26. criterion = nn.MSELoss()
  27. optimizer = optim.Adam(model.parameters(), lr=0.001)
  28. if mode == 'train':
  29. train = pd.read_csv('data/train.csv')
  30. for i in range(20):
  31. nloss = 0
  32. for k in range(len(train)):
  33. va = embedding.getem(train.iloc[k]['question'])
  34. vb = embedding.getem(train.iloc[k]['answer'])
  35. if train.iloc[k]['answer'] not in answerCache:
  36. answerCache.append(train.iloc[k]['answer'])
  37. trainTensor = model.prosess(va, vb)
  38. output = model(trainTensor)
  39. # 计算损失
  40. predict = torch.tensor(train.iloc[k]['label']).float()
  41. predict = predict.reshape([1, 1])
  42. loss = criterion(output,
  43. predict)
  44. # 反向传播并更新权重
  45. optimizer.zero_grad()
  46. loss.backward()
  47. optimizer.step()
  48. nloss += loss
  49. if k % 50 == 0:
  50. print(i, k, 'done')
  51. print('one loop done', nloss/len(train))
  52. torch.save(model, 'models/%s.pth' % modelName)
  53. if mode == 'test':
  54. n = 0
  55. model = torch.load('models/%s.pth' % modelName)
  56. model.eval()
  57. test = pd.read_csv('data/test.csv')
  58. for i in range(len(test)):
  59. va = embedding.getem(test.iloc[i]['question'])
  60. vb = embedding.getem(test.iloc[i]['answer'])
  61. testTensor = model.prosess(va, vb)
  62. output = model(testTensor)
  63. if output > 0.5 and test.iloc[i]['label'] == 1:
  64. n += 1
  65. if output < 0.5 and test.iloc[i]['label'] == 0:
  66. n += 1
  67. print(n/len(test))
  68. if __name__ == '__main__':
  69. if not len(sys.argv) == 4:
  70. arg1 = 'cnn'
  71. arg2 = 'dashscope'
  72. arg3 = 'train'
  73. else:
  74. # 从命令行参数中获取参数值
  75. arg1 = sys.argv[1]
  76. arg2 = sys.argv[2]
  77. arg3 = sys.argv[3]
  78. main(model=arg1, embedding=arg2, mode=arg3)