I am trying to implement early stopping and the loop does not break I have tried nested(continue, break statements). the code initializes the loop but it doesn’t break the loop and early stopping doesn’t stop the training.
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np
from matplotlib import pyplot as plt
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
train_loader = torch.utils.data.DataLoader(dsets.MNIST('/Users/manishrai/Downloads', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dsets.MNIST('/Users/manishrai/Downloads', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_test, shuffle=True)
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
print(fig)
class LSTMModel(nn.Module):
def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
super(LSTMModel, self).__init__()
# Hidden dimensions
self.hidden_dim = hidden_dim
# Number of hidden layers
self.layer_dim = layer_dim
# Building your LSTM
# batch_first=True causes input/output tensors to be of shape
# (batch_dim, seq_dim, feature_dim)
self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
# Readout layer
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
# Initialize hidden state with zeros
h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
# Initialize cell state
c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
# 28 time steps
# We need to detach as we are doing truncated backpropagation through time (BPTT)
# If we don't, we'll backprop all the way to the start even after going through another batch
out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
# Index hidden state of last time step
# out.size() --> 100, 28, 100
# out[:, -1, :] --> 100, 100 --> just want last time step hidden states!
out = self.fc(out[:, -1, :])
# out.size() --> 100, 10
return out
input_dim = 28
hidden_dim = 100
layer_dim = 1
output_dim = 10
batch_size = 100
n_iters = 3000
min_val_loss = np.Inf
#num_epochs = n_iters / (len(train_loader) / batch_size)
num_epochs = 5
n_epochs_stop = 5
epochs_no_improve = 0
model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Number of steps to unroll
seq_dim = 28
iter = 0
for epoch in range(num_epochs):
val_loss = 0
for i, (images, labels) in enumerate(train_loader):
# Load images as a torch tensor with gradient accumulation abilities
images = images.view(-1, seq_dim, input_dim).requires_grad_()
# Clear gradients w.r.t. parameters
optimizer.zero_grad()
# Forward pass to get output/logits
# outputs.size() --> 100, 10
outputs = model(images)
# Calculate Loss: softmax --> cross entropy loss
loss = criterion(outputs, labels)
# Getting gradients w.r.t. parameters
loss.backward()
# Updating parameters
optimizer.step()
val_loss += loss
val_loss = val_loss / len(train_loader)
# If the validation loss is at a minimum
if val_loss < min_val_loss:
# Save the model
#torch.save(model)
epochs_no_improve = 0
min_val_loss = val_loss
else:
epochs_no_improve += 1
# Check early stopping condition
if epochs_no_improve == n_epochs_stop:
print('Early stopping!' )
break
iter += 1
if iter % 500 == 0:
# Calculate Accuracy
correct = 0
total = 0
# Iterate through test dataset
for images, labels in test_loader:
# Resize images
images = images.view(-1, seq_dim, input_dim)
# Forward pass only to get logits/output
outputs = model(images)
# Get predictions from the maximum value
_, predicted = torch.max(outputs.data, 1)
# Total number of labels
total += labels.size(0)
# Total correct predictions
correct += (predicted == labels).sum()
accuracy = 100 * correct / total
# Print Loss
print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))
Figure(432x288)
Iteration: 500. Loss: 1.1461071968078613. Accuracy: 59
Iteration: 1000. Loss: 0.2487064003944397. Accuracy: 89
Iteration: 1500. Loss: 0.18170203268527985. Accuracy: 92
Iteration: 2000. Loss: 0.262026309967041. Accuracy: 95
Iteration: 2500. Loss: 0.09644989669322968. Accuracy: 96
Early stopping!
Iteration: 3000. Loss: 0.27113696932792664. Accuracy: 96
Iteration: 3500. Loss: 0.15842244029045105. Accuracy: 97
I want to stop this code after early stopping. I have tried nested looping to check if the outer loop breaks and it doesnt work. What do I do?