Noisy MNIST training gives surprisingly good results on clean test data

I have the following code to add noise to the MNIST labels:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import random
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import StratifiedShuffleSplit

sns.set(rc={'figure.figsize': (12, 12)})


def add_noise(fraction, data, noise_perm):
    X_train, y_train = data.data.numpy(), data.targets.numpy()
    # the noisy label of each sample is the permuted version of its true class
    noise = noise_perm[y_train]

    # pick a stratified `fraction` of the indices to corrupt
    _, noise_idx = next(iter(StratifiedShuffleSplit(n_splits=1,
                                                    test_size=fraction,
                                                    random_state=2020).split(X_train, y_train)))
    y_train[noise_idx] = noise[noise_idx]
    return torch.from_numpy(y_train.astype('int64'))
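
For intuition, a tiny standalone check (with made-up labels, not the real data) of what noise_perm does: each class c is remapped to noise_perm[c] for the selected fraction of samples:

import numpy as np

noise_perm = np.array([7, 9, 0, 4, 2, 1, 3, 5, 6, 8])
toy_labels = np.array([0, 1, 2, 3])
print(noise_perm[toy_labels])  # [7 9 0 4]: class 0 becomes 7, class 1 becomes 9, ...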


def validate_noise_level(noise_labels, clean_labels):
    print("Actual Noise Level:")
    print(1. - np.mean(noise_labels == clean_labels))


def getMNISTData(batch_sizes, noise_frac=0.46, noise_perm=np.array([7, 9, 0, 4, 2, 1, 3, 5, 6, 8]), test_noise=False):
    data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                     torchvision.transforms.Normalize((0.1307,), (0.3081,))])

    # download the MNIST data
    train_ds = torchvision.datasets.MNIST('files/', train=True, download=True, transform=data_transform)
    test_ds = torchvision.datasets.MNIST('files/', train=False, download=True, transform=data_transform)
    # add noise to labels
    clean_labels = train_ds.targets.numpy().copy()
    train_ds.targets = add_noise(noise_frac, train_ds, noise_perm)
    validate_noise_level(train_ds.targets.numpy(), clean_labels)
    train_ds, val_ds = torch.utils.data.random_split(train_ds, [50000, 10000])
    if test_noise:
        test_ds.targets = add_noise(noise_frac, test_ds, noise_perm)

    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_sizes['train'], shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_sizes['test'], shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_sizes['test'], shuffle=True)

    return train_loader, val_loader, test_loader

This creates the train/validation/test loaders, where both the train and validation sets have noisy labels.
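
As a quick sanity check (a sketch, assuming the loaders above with a train batch size of 256), the shapes of one batch can be inspected:

x, t = next(iter(train_loader))
print(x.shape, t.shape)  # torch.Size([256, 1, 28, 28]) torch.Size([256])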

My simple MLP is:

class ClassifierMLP(nn.Module):
    def __init__(self, in_size, hidden_sizes, num_classes, dropout_p=0.5):
        assert len(hidden_sizes) == 2, "Should be only 2 hidden layers"
        super(ClassifierMLP, self).__init__()
        self.in_size = in_size
        self.fc1 = nn.Linear(in_features=in_size, out_features=hidden_sizes[0])
        self.fc2 = nn.Linear(in_features=hidden_sizes[0], out_features=hidden_sizes[1])
        self.classifier = nn.Linear(in_features=hidden_sizes[1], out_features=num_classes)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        x = x.view(-1, self.in_size)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        return torch.softmax(self.classifier(x), dim=-1)
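
One thing worth flagging in the forward above: nn.CrossEntropyLoss (used in training below) applies log_softmax internally, so it expects raw logits, and the softmax here is applied on top of that. A minimal check with dummy tensors shows the two losses differ:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, -1.0, 0.5]])
target = torch.tensor([0])
print(loss_fn(logits, target))                         # loss on raw logits
print(loss_fn(torch.softmax(logits, dim=-1), target))  # different loss on probabilities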

When trained, the validation accuracy is around 50%, which is fine because there is almost 46% label noise, but when I test on the clean test data it gives ~90%.
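
For reference, my back-of-the-envelope expectation: noise_perm has no fixed points, so a flipped label is always wrong, and a classifier that always predicts the true label would score exactly 1 - 0.46 = 54% on the noisy validation set while still being perfect on clean labels. A quick simulation with random synthetic labels (not the real data):

import numpy as np

rng = np.random.default_rng(0)
true = rng.integers(0, 10, size=100000)
noise_perm = np.array([7, 9, 0, 4, 2, 1, 3, 5, 6, 8])
noisy = true.copy()
flip = rng.random(true.shape) < 0.46
noisy[flip] = noise_perm[true[flip]]
print((true == noisy).mean())  # ~0.54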

The same code in TensorFlow gives around 75%.

I'll add the full code for the training:

def train_step(model, train_loader, opt, loss_fn, device, epoch_num, log_interval, batch_size, stats):
    model.train()
    for batch_idx, (x, t) in enumerate(train_loader):
        opt.zero_grad()
        y_hat = model(x.to(device))
        loss = loss_fn(y_hat, t.to(device))
        loss.backward()
        opt.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch_num, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            stats['train loss'].append(loss.item())
            stats['train iter'].append((batch_idx * batch_size) + ((epoch_num - 1) * len(train_loader.dataset)))
    return model


def test_step(model, test_loader, loss_fn, device, stats):
    model.eval()
    test_loss = 0
    num_correct = 0
    with torch.no_grad():
        for x, t in test_loader:
            y_hat = model(x.to(device))
            # CrossEntropyLoss averages over the batch, so scale by the batch size
            # to accumulate a per-sample sum before dividing by the dataset size
            test_loss += loss_fn(y_hat, t.to(device)).item() * x.size(0)
            pred = y_hat.argmax(dim=1, keepdim=True)
            num_correct += pred.eq(t.to(device).view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    stats['test loss'].append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, num_correct, len(test_loader.dataset),
        100. * num_correct / len(test_loader.dataset)))

    return 100. * num_correct / len(test_loader.dataset)


if __name__ == '__main__':
    batches = {"train": 256, "test": 256}
    n_epochs = 20
    log_interval = 100
    learning_rate = 0.001
    n_hidden = [256, 64]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader, val_loader, test_loader = getMNISTData(batches)

    model = ClassifierMLP(in_size=784, hidden_sizes=n_hidden, num_classes=10, dropout_p=0.5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    stats_dict = {"train loss": [],
                  "train iter": [],
                  "test loss": []}

    os.makedirs("results_base", exist_ok=True)
    dir_name = "results_base"

    best_acc = 0.

    # baseline accuracy on the clean test set before any training
    test_step(model,
              test_loader,
              loss_fn,
              device,
              stats_dict)
    for epoch in range(1, n_epochs + 1):
        model = train_step(model,
                           train_loader,
                           optimizer,
                           loss_fn,
                           device,
                           epoch,
                           log_interval,
                           batches['train'],
                           stats_dict)

        acc = test_step(model,
                        val_loader,
                        loss_fn,
                        device,
                        stats_dict)

        if acc >= best_acc:
            best_acc = acc
            torch.save(model.state_dict(), os.path.join(dir_name, 'model.pth'))
            torch.save(optimizer.state_dict(), os.path.join(dir_name, 'optimizer.pth'))

    np.save(os.path.join(dir_name, "stats.npy"), stats_dict)
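
To read the stats back later (np.save pickles the dict, so allow_pickle is needed when loading), something like:

stats = np.load(os.path.join("results_base", "stats.npy"), allow_pickle=True).item()
print(stats["test loss"])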

If the TensorFlow code is also needed, I'll add it later.
