Non puoi selezionare più di 25 argomenti. Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.

86 righe
2.9 KiB

  1. import sys
  2. import configparser
  3. import pandas as pd
  4. import torch
  5. from torch import nn
  6. import torch.optim as optim
  7. def main(**param):
  8. L ={"dashscope":1536, "liandong":1024, "zhipu":1024}
  9. answerCache = []
  10. config = configparser.ConfigParser()
  11. config.read("settings.ini", encoding="utf-8")
  12. modelName = param['model']
  13. embeddings = param['embedding']
  14. mode = param['mode']
  15. print(modelName, embeddings, mode)
  16. emName = embeddings + 'embedding'
  17. embedding_config = dict(config.items(emName))
  18. ems = __import__('embeddings.%s' % emName,
  19. fromlist=['embeddings'])
  20. string = 'ems.' + emName.capitalize()
  21. embedding = eval(string)(**embedding_config)
  22. mds = __import__('models.%s' % modelName,
  23. fromlist=['models'])
  24. string = 'mds.' + modelName.capitalize()
  25. model = eval(string)(L[embeddings])
  26. criterion = nn.MSELoss()
  27. optimizer = optim.Adam(model.parameters(), lr=0.001)
  28. if mode == 'train':
  29. train = pd.read_csv('data/train.csv')
  30. for i in range(20):
  31. nloss = 0
  32. for k in range(len(train)):
  33. va = embedding.getem(train.iloc[k]['question'])
  34. vb = embedding.getem(train.iloc[k]['answer'])
  35. if train.iloc[k]['answer'] not in answerCache:
  36. answerCache.append(train.iloc[k]['answer'])
  37. trainTensor = model.prosess(va, vb)
  38. output = model(trainTensor)
  39. # 计算损失
  40. predict = torch.tensor(train.iloc[k]['label']).float()
  41. predict = predict.reshape([1, 1])
  42. loss = criterion(output,
  43. predict)
  44. # 反向传播并更新权重
  45. optimizer.zero_grad()
  46. loss.backward()
  47. optimizer.step()
  48. nloss += loss
  49. if k % 50 == 0:
  50. print(i, k, 'done')
  51. print('one loop done', nloss/len(train))
  52. torch.save(model, 'models/%s.pth' % modelName)
  53. if mode == 'test':
  54. n = 0
  55. model = torch.load('models/%s.pth' % modelName)
  56. model.eval()
  57. test = pd.read_csv('data/test.csv')
  58. for i in range(len(test)):
  59. va = embedding.getem(test.iloc[i]['question'])
  60. vb = embedding.getem(test.iloc[i]['answer'])
  61. testTensor = model.prosess(va, vb)
  62. output = model(testTensor)
  63. if output > 0.5 and test.iloc[i]['label'] == 1:
  64. n += 1
  65. if output < 0.5 and test.iloc[i]['label'] == 0:
  66. n += 1
  67. print(n/len(test))
  68. if __name__ == '__main__':
  69. if not len(sys.argv) == 4:
  70. arg1 = 'cnn'
  71. arg2 = 'dashscope'
  72. arg3 = 'train'
  73. else:
  74. # 从命令行参数中获取参数值
  75. arg1 = sys.argv[1]
  76. arg2 = sys.argv[2]
  77. arg3 = sys.argv[3]
  78. main(model=arg1, embedding=arg2, mode=arg3)