train_main.py

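"""Train and evaluate the BiRNN user classifier defined in model.py.

Data is served by UserDataset from the final_train_dict.txt / final_test_dict.txt
path lists; class imbalance on the training side is handled with
ImbalancedDatasetSampler from torchsampler.
"""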
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsampler import ImbalancedDatasetSampler
from tqdm import tqdm
from sklearn import metrics

from model import BiRNN
from userdataset import UserDataset
from config import Config

sys.path.append('../')
from common.log_utils import logFactory

logger = logFactory("train_main").log
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
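# Config (config.py) is expected to provide: input_size, hidden_dim, label_size,
# batch_size, num_layers, use_gpu, learning_rate, epochs, shuffle_train_data,
# use_plot, plot_save, use_save, and adjust_learning_rate(optimizer, epoch),
# which returns the (possibly updated) optimizer and the lr now in effect.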
def train(train_loader, test_loader):
    model = BiRNN(input_size=Config.input_size, hidden_dim=Config.hidden_dim, label_size=Config.label_size,
                  batch_size=Config.batch_size, num_layers=Config.num_layers, use_gpu=Config.use_gpu)
    if Config.use_gpu:
        model = model.cuda()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=Config.learning_rate)

    train_loss_ = []
    test_loss_ = []
    train_acc_ = []
    test_acc_ = []
    logger.info(f"Starting training for {Config.epochs} epochs")
    for epoch in range(Config.epochs):
        optimizer, lr = Config.adjust_learning_rate(optimizer, epoch)
        logger.info(f'epoch-{epoch}, current lr: {lr}')
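        # BiRNN keeps batch_size and hidden as attributes, so both are reset per
        # batch below: the last batch of an epoch can be smaller than
        # Config.batch_size, and hidden state should not leak across batches.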
        # train epoch
        model.train()
        total_acc = 0.0
        total_loss = 0.0
        total = 0.0
        for index, (contents, labels) in enumerate(tqdm(train_loader)):
            try:
                train_inputs = contents
                train_labels = torch.squeeze(labels)
                if Config.use_gpu:
                    train_inputs, train_labels = train_inputs.cuda(), train_labels.cuda()

                model.zero_grad()
                model.batch_size = len(train_labels)
                model.hidden = model.init_hidden()
                output = model(train_inputs)

                loss = loss_function(output, train_labels)
                loss.backward()
                optimizer.step()

                # calc training acc
                _, predicted = torch.max(output, 1)
                total_acc += (predicted == train_labels).sum().item()
                total += len(train_labels)
                total_loss += loss.item()
                if index % 64 == 0:
                    logger.info(f'Epoch {epoch}, batch {index}, loss = {loss.item():.4f}')
                    result = metrics.classification_report(train_labels.cpu().numpy(), predicted.cpu().numpy())
                    logger.info('in train epoch')
                    logger.info(result)
            except Exception as e:
                logger.exception(e)
                continue
        train_loss_.append(total_loss / total)
        train_acc_.append(total_acc / total)
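        # Evaluation pass: model.eval() switches off dropout/batch-norm updates
        # and torch.no_grad() skips autograd bookkeeping; neither changes the
        # math of the forward pass.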
        # test epoch
        model.eval()
        total_acc = 0.0
        total_loss = 0.0
        total = 0.0
        with torch.no_grad():
            for index, (contents, labels) in enumerate(test_loader):
                try:
                    test_inputs = contents
                    # CrossEntropyLoss expects int64 targets
                    test_labels = torch.squeeze(labels).to(torch.int64)
                    if Config.use_gpu:
                        test_inputs, test_labels = test_inputs.cuda(), test_labels.cuda()

                    model.batch_size = len(test_labels)
                    model.hidden = model.init_hidden()
                    output = model(test_inputs)
                    loss = loss_function(output, test_labels)

                    # calc testing acc
                    _, predicted = torch.max(output, 1)
                    total_acc += (predicted == test_labels).sum().item()
                    total += len(test_labels)
                    total_loss += loss.item()
                    if index % 64 == 0:
                        logger.info(f'Epoch {epoch}, batch {index}, loss = {loss.item():.4f}')
                        result = metrics.classification_report(test_labels.cpu().numpy(), predicted.cpu().numpy())
                        logger.info('in test epoch')
                        logger.info(result)
                except Exception as e:
                    logger.exception(e)
                    continue
        test_loss_.append(total_loss / total)
        test_acc_.append(total_acc / total)
        logger.info('[Epoch: %3d/%3d] Training Loss: %.3f, Testing Loss: %.3f, Training Acc: %.3f, Testing Acc: %.3f'
                    % (epoch, Config.epochs, train_loss_[epoch], test_loss_[epoch], train_acc_[epoch], test_acc_[epoch]))
    result = {}
    result['train loss'] = train_loss_
    result['test loss'] = test_loss_
    result['train acc'] = train_acc_
    result['test acc'] = test_acc_
    if Config.use_plot:
        import PlotFigure as PF
        PF.PlotFigure(result, Config.plot_save)
    if Config.use_save:
        torch.save(model, "./model")
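# Reloading the saved model elsewhere, a sketch: torch.save(model, ...) above
# pickles the whole module, so the BiRNN class from model.py must be importable
# at load time.
#   model = torch.load("./model", map_location=device)
#   model.eval()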
def start():
    train_data_set = UserDataset("../data/csv_lstm_data/final_train_dict.txt", "train_data")
    # shuffle and sampler are mutually exclusive in DataLoader, so shuffle is
    # left off while ImbalancedDatasetSampler is in use.
    train_loader = DataLoader(dataset=train_data_set, batch_size=Config.batch_size,
                              sampler=ImbalancedDatasetSampler(train_data_set))
    # train_loader = DataLoader(dataset=train_data_set, batch_size=Config.batch_size,
    #                           shuffle=Config.shuffle_train_data, num_workers=8)
    test_data_set = UserDataset("../data/csv_lstm_data/final_test_dict.txt", "test_data")
    # no resampling here: oversampling the test set would bias the reported metrics
    test_loader = DataLoader(dataset=test_data_set, batch_size=Config.batch_size, shuffle=False)
    # test_loader = DataLoader(dataset=test_data_set, batch_size=Config.batch_size,
    #                          shuffle=Config.shuffle_train_data, num_workers=8)
    train(train_loader, test_loader)
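# If torchsampler is unavailable, a hand-rolled alternative inside start() would
# be torch.utils.data.WeightedRandomSampler (a sketch, assuming UserDataset
# yields (contents, label) pairs with 0/1 labels, as the training loop expects):
#   labels = torch.tensor([int(label) for _, label in train_data_set])
#   class_counts = torch.bincount(labels)
#   sample_weights = 1.0 / class_counts[labels].float()
#   sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))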
if __name__ == "__main__":
    start()
# Scratch: one-off data-prep script used to build the train/test dict files,
# kept commented out for reference (needs `random` and `pandas` if re-enabled).
# f = open("../data/csv_lstm_data/path_dict.txt", "r")
# lines = f.read().split("\n")
# label_dict = {}
# pos = []
# neg = []
# for line in lines:
#     if line == "":
#         continue
#     file_path = line.split(" ")[0]
#     label_value = line.split(" ")[1]
#     if label_value == "0":
#         pos.append(file_path)
#     else:
#         neg.append(file_path)
# pos_test = random.sample(pos, int(len(pos) / 5))
# pos_train = [elem for elem in pos if elem not in pos_test]
#
# neg_test = random.sample(neg, int(len(neg) / 5))
# neg_train = [elem for elem in neg if elem not in neg_test]
#
# with open("../data/csv_lstm_data/train_dict.txt", "a") as f:
#     for e in pos_train:
#         f.write(e + " " + "0" + "\n")
#     for e in neg_train:
#         f.write(e + " " + "1" + "\n")
#
# with open("../data/csv_lstm_data/test_dict.txt", "a") as f:
#     for e in pos_test:
#         f.write(e + " " + "0" + "\n")
#     for e in neg_test:
#         f.write(e + " " + "1" + "\n")
#
# train_final = []
# test_final = []
# for i, path in enumerate(["../data/csv_lstm_data/train_dict.txt", "../data/csv_lstm_data/test_dict.txt"]):
#     with open(path, "r") as f:
#         lines = f.read()
#         line = lines.split("\n")
#         for l in line:
#             path = "." + l.split(" ")[0]
#             label = l.split(" ")[1]
#             df = pd.read_pickle(path)
#             if len(df) == 2:
#                 if i == 0:
#                     train_final.append((path, label))
#                 else:
#                     test_final.append((path, label))
#
# with open('../data/csv_lstm_data/final_train_dict.txt', "a") as f:
#     for e in train_final:
#         f.write(e[0] + " " + str(e[1]) + "\n")
#
# with open('../data/csv_lstm_data/final_test_dict.txt', "a") as f:
#     for e in test_final:
#         f.write(e[0] + " " + str(e[1]) + "\n")