Early stopping: the loop does not 'break' and training continues

I am trying to implement early stopping, but the loop never breaks and training continues. I have tried nested continue/break statements. The code enters the loop, but the break is never reached the way I expect, so early stopping doesn't stop the training.
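For reference, Python's break only exits the innermost loop it appears in, so a break inside the batch loop never leaves the epoch loop. A minimal standalone sketch of the usual flag pattern (the name should_stop and the stopping condition here are placeholders, not part of the code below):

should_stop = False  # flag shared by both loops (placeholder name)

for epoch in range(3):
    for batch_idx in range(100):
        if epoch == 1 and batch_idx == 5:  # stand-in for the early-stopping check
            should_stop = True
            break                          # exits only the inner (batch) loop
    if should_stop:
        break                              # exits the outer (epoch) loop as well

print(epoch, batch_idx)  # -> 1 5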

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np
from matplotlib import pyplot as plt

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

train_loader = torch.utils.data.DataLoader(dsets.MNIST('/Users/manishrai/Downloads', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  dsets.MNIST('/Users/manishrai/Downloads', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
print(fig)


class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        # Hidden dimensions
        self.hidden_dim = hidden_dim

        # Number of hidden layers
        self.layer_dim = layer_dim

        # Building your LSTM
        # batch_first=True causes input/output tensors to be of shape
        # (batch_dim, seq_dim, feature_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)

        # Readout layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        # Initialize cell state
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        # 28 time steps
        # We need to detach as we are doing truncated backpropagation through time (BPTT)
        # If we don't, we'll backprop all the way to the start even after going through another batch
        out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))

        # Index the hidden state of the last time step
        # out has shape (batch_size, seq_len, hidden_dim)
        # out[:, -1, :] has shape (batch_size, hidden_dim): the last time step only
        out = self.fc(out[:, -1, :])
        # out now has shape (batch_size, output_dim)
        return out

input_dim = 28
hidden_dim = 100
layer_dim = 1
output_dim = 10
batch_size = 100   # note: unused here; the DataLoaders above use 64 / 1000
n_iters = 3000
min_val_loss = np.inf
num_epochs = 5
n_epochs_stop = 5
epochs_no_improve = 0
model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)

criterion = nn.CrossEntropyLoss()

learning_rate = 0.1

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  

# Number of steps to unroll
seq_dim = 28  

iter = 0  # note: this shadows the built-in iter()
for epoch in range(num_epochs):
    val_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        # Load images as a torch tensor with gradient accumulation abilities
        images = images.view(-1, seq_dim, input_dim).requires_grad_()

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        # outputs.size() --> 100, 10
        outputs = model(images)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()
        
        val_loss += loss
        val_loss = val_loss / len(train_loader)
        # If the validation loss is at a minimum
        if val_loss < min_val_loss:
            # Save the model
            # torch.save(model)
            epochs_no_improve = 0
            min_val_loss = val_loss
        else:
            epochs_no_improve += 1
        # Check early stopping condition
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            break
        
        
        iter += 1

        if iter % 500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                # Resize images
                images = images.view(-1, seq_dim, input_dim)

                # Forward pass only to get logits/output
                outputs = model(images)

                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)

                # Total number of labels
                total += labels.size(0)

                # Total correct predictions
                correct += (predicted == labels).sum()

            accuracy = 100 * correct / total

            # Print Loss
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))
    
       
Running this produces:

Figure(432x288)
Iteration: 500. Loss: 1.1461071968078613. Accuracy: 59
Iteration: 1000. Loss: 0.2487064003944397. Accuracy: 89
Iteration: 1500. Loss: 0.18170203268527985. Accuracy: 92
Iteration: 2000. Loss: 0.262026309967041. Accuracy: 95
Iteration: 2500. Loss: 0.09644989669322968. Accuracy: 96
Early stopping!
Iteration: 3000. Loss: 0.27113696932792664. Accuracy: 96
Iteration: 3500. Loss: 0.15842244029045105. Accuracy: 97

I want this code to stop once early stopping triggers. I have tried nesting the loops and checking whether the outer loop breaks, but it doesn't work. What do I do?

I was able to stop the training by setting a flag and adding a second break outside the inner loop, as shown below. Please consider this question resolved.

early_stop = False  # initialize the flag before the epoch loop

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # ... training step and patience update as above ...

        # Check early stopping condition
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            early_stop = True
            break  # leaves only the inner (batch) loop
    if early_stop:
        print("Stopped")
        break      # leaves the outer (epoch) loop as well
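One more caveat: in the loop above, val_loss is really the running training loss, and the patience counter epochs_no_improve is updated after every batch, so n_epochs_stop batches (not epochs) without improvement trigger the stop. A sketch of the more usual pattern, with one validation pass and one patience update per epoch; the evaluate helper is hypothetical and reuses the variables defined earlier:

def evaluate(model, loader, criterion, seq_dim, input_dim):
    """Average loss over a held-out loader (hypothetical helper)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in loader:
            images = images.view(-1, seq_dim, input_dim)
            total_loss += criterion(model(images), labels).item()
    model.train()
    return total_loss / len(loader)

for epoch in range(num_epochs):
    # --- training ---
    for images, labels in train_loader:
        images = images.view(-1, seq_dim, input_dim)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # --- one validation pass and one patience update per epoch ---
    val_loss = evaluate(model, test_loader, criterion, seq_dim, input_dim)
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
    if epochs_no_improve == n_epochs_stop:
        print('Early stopping!')
        break  # we are already in the epoch loop, so one break suffices

Alternatively, wrapping the whole training loop in a function lets a single return exit every loop at once, with no flag needed.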

Thanks,
Akhilesh