# train_lstm.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01
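
# Each 28x28 MNIST image is read as a sequence: sequence_length = 28 time
# steps (one image row per step), each step an input_size = 28 vector.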

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='/data/lstm_data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='/data/lstm_data',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
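
# Each loader yields (images, labels) batches; images arrive with shape
# (batch_size, 1, 28, 28) and are reshaped to
# (batch_size, sequence_length, input_size) before entering the LSTM.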

# Recurrent neural network (many-to-one)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initialize the hidden state and cell state; they share the same
        # shape, (num_layers, batch_size, hidden_size); x.size(0) is the batch size
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
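
# Optional sanity check (not part of the original script): a dummy batch of
# four sequences should yield logits of shape (4, num_classes).
with torch.no_grad():
    _dummy = torch.randn(4, sequence_length, input_size).to(device)
    assert model(_dummy).shape == (4, num_classes)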

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
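# Note: nn.CrossEntropyLoss applies log-softmax internally, which is why
# forward() returns raw logits rather than probabilities.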

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
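
# Note (not in the original script): if training diverges, RNNs often benefit
# from gradient clipping between loss.backward() and optimizer.step(), e.g.
#     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)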

# Test the model
model.eval()  # good practice, though this model has no dropout/batch norm
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
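
# To reload the checkpoint later (a minimal sketch using the path saved above):
#     model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
#     model.load_state_dict(torch.load('model.ckpt', map_location=device))
#     model.eval()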