Accuracy of the model drops drastically

I am training a model on MNIST, and partway through training the accuracy suddenly drops to about 10% (i.e. chance level).

Epoch: 37
468 469 Loss: 0.078 | Acc: 97.338% (58403/60000)
78 79 Loss: 0.046 | Acc: 98.860% (9886/10000)
Best Test Score tensor(98.8800)
Epoch:  38

Epoch: 38
468 469 Loss: 0.068 | Acc: 97.795% (58677/60000)
78 79 Loss: 0.053 | Acc: 98.780% (9878/10000)
Best Test Score tensor(98.8800)
Epoch:  39

Epoch: 39
468 469 Loss: 0.418 | Acc: 85.373% (51224/60000)
78 79 Loss: 2.294 | Acc: 12.560% (1356/10000)
Best Test Score tensor(98.8800)
Epoch:  40

Epoch: 40
468 469 Loss: 2.302 | Acc: 11.133% (6680/60000)
78 79 Loss: 2.301 | Acc: 10.610% (1061/10000)
Best Test Score tensor(98.8800)
=====> Saving checkpoint...
Epoch:  41

Epoch: 41
468 469 Loss: 2.304 | Acc: 11.155% (6693/60000)
78 79 Loss: 2.303 | Acc: 9.590% (959/10000)
Best Test Score tensor(98.8800)

I searched around and it looks like this is a bug in newer versions of PyTorch.
After disabling cuDNN for the BatchNorm layers, training works as long as I don't forward the test set through the model, although it slows training down by about 1.5x. Is there any workaround for this problem?
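
For reference, this is essentially what I mean by disabling cuDNN; the sketch below just flips the global backend flag (I am not aware of a public per-layer switch for only the BN layers, so this turns cuDNN off for all layers, which is probably where the ~1.5x slowdown comes from):

import torch

# Global switch: disables cuDNN for every layer, not only BatchNorm
torch.backends.cudnn.enabled = False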

The linked issue is from early-mid 2018 and seems to deal with PyTorch 0.4.
Could you post a code snippet to reproduce this issue in the latest stable version (1.5)?

from __future__ import print_function

import argparse
import numpy as np
import os
import csv
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as Data

from torch.optim import SGD, lr_scheduler
from collections import OrderedDict


# Checkpoint related
START_EPOCH = 0

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

'''
Usage:
python 
'''

class SmallCNN(nn.Module):
    def __init__(self, drop=0.5):
        super(SmallCNN, self).__init__()

        self.num_channels = 1
        self.num_labels = 10

        activ = nn.ReLU(True)

        self.feature_extractor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', activ),
            ('conv2', nn.Conv2d(32, 32, 3)),
            ('relu2', activ),
            ('maxpool1', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(32, 64, 3)),
            ('relu3', activ),
            ('conv4', nn.Conv2d(64, 64, 3)),
            ('relu4', activ),
            ('maxpool2', nn.MaxPool2d(2, 2)),
        ]))

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(64 * 4 * 4, 200)),
            ('relu1', activ),
            ('drop', nn.Dropout(drop)),
            ('fc2', nn.Linear(200, 200)),
            ('relu2', activ),
            ('fc3', nn.Linear(200, self.num_labels)),
        ]))

        for m in self.modules():
            if isinstance(m, (nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        nn.init.constant_(self.classifier.fc3.weight, 0)
        nn.init.constant_(self.classifier.fc3.bias, 0)

    def forward(self, input, with_latent=False, fake_relu=False, no_relu=False):
        features = self.feature_extractor(input)
        logits = self.classifier(features.view(-1, 64 * 4 * 4))
        return logits



class AttackPGD(nn.Module):
    """Adversarial training with PGD.
    """
    def __init__(self, model, config):
        super(AttackPGD, self).__init__()
        self.model = model
        self.rand = config['random_start']
        self.step_size = config['step_size']
        self.epsilon = config['epsilon']
        self.num_steps = config['num_steps']
        
        

    def forward(self, inputs, target, make_adv=False):
        # criterion, normalize and device are module-level globals defined in __main__
        x = inputs.detach()
        prev_training = bool(self.training)
        if make_adv:
            if self.rand:
                # random start inside the L_inf epsilon ball
                x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
            self.eval()
            for i in range(self.num_steps):
                x = x.clone().detach().requires_grad_(True)
                outputs = self.model(normalize(x).to(device))
                losses = criterion(outputs, target)
                loss = torch.mean(losses)
                grad, = torch.autograd.grad(loss, [x])
                with torch.no_grad():
                    # signed gradient step, then projection back into the epsilon ball
                    step = torch.sign(grad) * self.step_size
                    diff = x + step - inputs
                    diff = torch.clamp(diff, -self.epsilon, self.epsilon)
                    x = torch.clamp(diff + inputs, 0, 1)

        # classify the (possibly adversarial) input, then restore the original mode
        output = self.model(normalize(x.clone().detach()).to(device))
        if prev_training:
            self.train()
        return output, x


def train_glist(epoch):
    criterion = nn.CrossEntropyLoss()
    train_loss, correct, total = [[0] * len(netlist) for _ in range(3)]
    for idn in range(len(netlist)):
        print('\nEpoch: %d' % epoch)
        netlist[idn].train()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass (adversarial examples are generated inside AttackPGD)
            outputs, final_inp = netlist[idn](inputs, targets, make_adv=True)
            loss = criterion(outputs, targets).mean()

            # Backward and optimize
            optimizerlist[idn].zero_grad()
            loss.backward()
            optimizerlist[idn].step()

            with torch.no_grad():
                train_loss[idn] += loss.item()
                _, pred_idx = torch.max(outputs.data, 1)
                total[idn] += targets.size(0)
                correct[idn] += pred_idx.eq(targets.data).cpu().sum().float()

        print(batch_idx, len(train_loader),
              'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (train_loss[idn] / (batch_idx + 1), 100. * correct[idn] / total[idn], correct[idn], total[idn]))

        schedulelist[idn].step()

    return [1. / (batch_idx + 1) * t for t in train_loss], [100. / total[i] * correct[i] for i in range(len(correct))]


def pgdattack(model, inputs, targets, epsilon=8 / 255., step_size=2.0 / 255, num_steps=7, rand=False):
    # model is an AttackPGD wrapper, so the underlying classifier is model.model;
    # criterion and normalize are module-level globals defined in __main__
    x = inputs.detach()
    prev_training = bool(model.training)
    model.eval()
    for i in range(num_steps):
        x = x.clone().detach().requires_grad_(True)
        outputs = model.model(normalize(x))
        losses = criterion(outputs, targets)
        loss = torch.mean(losses)
        grad, = torch.autograd.grad(loss, [x])
        with torch.no_grad():
            step = torch.sign(grad) * step_size
            diff = x + step - inputs
            diff = torch.clamp(diff, -epsilon, epsilon)
            x = torch.clamp(diff + inputs, 0, 1)

    output = model.model(normalize(x.clone().detach()))
    if prev_training:
        model.train()
    return output, x



def testlist(epoch, idn):
    criterion = nn.CrossEntropyLoss()
    netlist[idn].eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs, final_inp = pgdattack(netlist[idn], inputs, targets)
        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, pred_idx = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += pred_idx.eq(targets.data).cpu().sum().float()

    print(batch_idx, len(test_loader),
                     'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return test_loss/(batch_idx+1), 100.*correct/total




def load_model(arg_model, model_path, arg_dict=1):
    model = globals()[arg_model]().to(device)

    import dill
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'), pickle_module=dill)
    print(checkpoint.keys())

    if 'net' in checkpoint:
        checkpoint['model'] = checkpoint['net']
        del checkpoint['net']

    if arg_dict:
        if 'model' in checkpoint:
            if hasattr(checkpoint['model'], 'state_dict'):
                # The checkpoint stores a full module rather than a plain state dict
                sd = checkpoint['model'].state_dict()
            else:
                sd = checkpoint['model']
        elif 'state_dict' in checkpoint:
            sd = checkpoint['state_dict']
            print('epoch', checkpoint['epoch'],
                  'arch', checkpoint['arch'],
                  'nat_prec1', checkpoint['nat_prec1'],
                  'adv_prec1', checkpoint['adv_prec1'])
        else:
            sd = checkpoint
            print(sd.keys())

        # Strip common wrapper prefixes so the keys match this model's state dict
        sd = {k.replace('module.attacker.model.', '').replace('module.model.', '').replace('module.', '').replace('model.', ''): v
              for k, v in sd.items()}

        keys = model.state_dict().keys()
        new_state = {}
        for k in sd.keys():
            if k in keys:
                new_state[k] = sd[k]
            else:
                print(k)

        model.load_state_dict(new_state)
    else:
        model = checkpoint['model']

    checkpoint = None
    sd = None

    model.eval().to(device)

    return model


if __name__ == '__main__':

    # Data
    print('=====> Preparing data...')

    transform_test = transforms.Compose([transforms.ToTensor(),])
    trainset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform_test)
    testset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform_test)
        

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

    unnormalize = lambda x: x
    normalize = lambda x: x

    
    #Configuration for Attack Model
    config = {
                'epsilon': 0.3,
                'step_size': 0.01,
                'random_start': False,
                'loss_func': 'xent',
                'num_steps': 40
               } 
    #Building the Models
    print('=====> Building model...')
    model_path=['checkpoint.pt.best']
    source_models=['SmallCNN']
    netlist = []
    for i, arg_model in enumerate(source_models):
        net = load_model(source_models[i], model_path[i])
        model0 = AttackPGD(net, config)
        model0 = model0.to(device)
        netlist += [model0]


    if torch.cuda.device_count() > 1:
        print("=====> Use", torch.cuda.device_count(), "GPUs")
        for idn in range(len(netlist)):
            netlist[idn] = nn.DataParallel(netlist[idn])

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    # Hard-coded base parameters
    train_args =  {
            "lr": 0.1,
            "weight_decay": 1e-06,
            "momentum": 0.9,
            "step_lr": 50,
            "epoch": 200}


    optimizerlist, schedulelist = [], []
    best_acc = 0
    for idn in range(len(netlist)):
        # Make optimizer
        param_list = netlist[idn].parameters()
        optimizer = SGD(param_list, train_args["lr"], train_args["momentum"],
                                weight_decay=train_args["weight_decay"])

        schedule = lr_scheduler.StepLR(optimizer, step_size=train_args["step_lr"])
        optimizerlist += [optimizer]
        schedulelist += [schedule]
    for epoch in range(START_EPOCH, train_args["epoch"]):
        print("Epoch: ", epoch)
        train_loss, train_acc = train_glist(epoch)
        test_loss, test_acc = testlist(epoch, 0)

Link for the checkpoint: checkpoint

Output:

Epoch: 33
468 469 Loss: 0.081 | Acc: 97.318% (58391/60000)
78 79 Loss: 0.051 | Acc: 98.800% (9880/10000)
Epoch:  34

Epoch: 34
468 469 Loss: 0.077 | Acc: 97.445% (58467/60000)
78 79 Loss: 0.048 | Acc: 98.860% (9886/10000)
Epoch:  35

Epoch: 35
468 469 Loss: 0.073 | Acc: 97.585% (58551/60000)
78 79 Loss: 0.056 | Acc: 98.770% (9877/10000)
Epoch:  36

Epoch: 36
468 469 Loss: 0.062 | Acc: 97.852% (58711/60000)
78 79 Loss: 0.057 | Acc: 98.750% (9875/10000)
Epoch:  37

Epoch: 37
468 469 Loss: 0.078 | Acc: 97.338% (58403/60000)
78 79 Loss: 0.046 | Acc: 98.860% (9886/10000)
Epoch:  38

Epoch: 38
468 469 Loss: 0.068 | Acc: 97.795% (58677/60000)
78 79 Loss: 0.053 | Acc: 98.780% (9878/10000)
Epoch:  39

Epoch: 39
468 469 Loss: 0.418 | Acc: 85.373% (51224/60000)
78 79 Loss: 2.294 | Acc: 12.560% (1356/10000)
Epoch:  40

Epoch: 40
468 469 Loss: 2.302 | Acc: 11.133% (6680/60000)
78 79 Loss: 2.301 | Acc: 10.610% (1061/10000)
Epoch:  41

Epoch: 41
468 469 Loss: 2.304 | Acc: 11.155% (6693/60000)
78 79 Loss: 2.303 | Acc: 9.590% (959/10000)
Epoch:  42

Epoch: 42
468 469 Loss: 2.303 | Acc: 10.870% (6522/60000)
78 79 Loss: 2.302 | Acc: 10.280% (1028/10000)
Epoch:  43

Epoch: 43
468 469 Loss: 2.302 | Acc: 11.327% (6796/60000)
78 79 Loss: 2.301 | Acc: 10.510% (1051/10000)

This is the output from another run:

Epoch: 47
468 469 Loss: 0.061 | Acc: 97.937% (58762/60000)
78 79 Loss: 0.059 | Acc: 98.710% (9871/10000)
Best Test Score tensor(98.8800)
Epoch: 48

Epoch: 48
468 469 Loss: 1.639 | Acc: 38.638% (23183/60000)
78 79 Loss: 2.302 | Acc: 10.280% (1028/10000)
Best Test Score tensor(98.8800)
Epoch: 49

Epoch: 49
468 469 Loss: 2.302 | Acc: 11.065% (6639/60000)
78 79 Loss: 2.302 | Acc: 10.280% (1028/10000)
Best Test Score tensor(98.8800)

Thanks for the update. Based on the printed stats it seems you are overtraining your model: the training accuracy is already noisy around 97-98%, so a large update might knock the model out of the "good" region of parameter space.
You could try lowering the learning rate of your optimizer(s) or using an adaptive optimizer such as Adam.
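
Something along these lines, as a rough sketch against the optimizer setup in your script (netlist comes from your code; the concrete learning rates below are just examples, not tuned values):

from torch.optim import Adam, SGD, lr_scheduler

# Option 1: keep SGD, but with a smaller learning rate than 0.1
optimizer = SGD(netlist[0].parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6)

# Option 2: switch to an adaptive optimizer instead
# optimizer = Adam(netlist[0].parameters(), lr=1e-3, weight_decay=1e-6)

schedule = lr_scheduler.StepLR(optimizer, step_size=50)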

Also, were you able to reproduce this effect without using cudnn?

Actually no, the problem was a too high learning rate.
Thanks for pointing out the issue.
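
For completeness, the fix boiled down to lowering the hard-coded learning rate in train_args, along these lines (the value 0.01 is only an illustration of a smaller setting, not a tuned one):

train_args = {
        "lr": 0.01,          # was 0.1
        "weight_decay": 1e-06,
        "momentum": 0.9,
        "step_lr": 50,
        "epoch": 200}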