Error in CNN Training with MNIST but not CIFAR10/CIFAR100

Hello everyone,
I’m currently following this:

and running main_frequentist.py:

from __future__ import print_function

import os
import argparse

import torch
import numpy as np
import torch.nn as nn
from torch.optim import Adam

import data
import config_frequentist as cfg
from models.NonBayesianModels.AlexNet import AlexNet
from models.NonBayesianModels.LeNet import LeNet
from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC

# CUDA settings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def getModel(net_type, inputs, outputs):
    if net_type == 'lenet':
        return LeNet(outputs, inputs)
    elif net_type == 'alexnet':
        return AlexNet(outputs, inputs)
    elif net_type == '3conv3fc':
        return ThreeConvThreeFC(outputs, inputs)
    else:
        raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC]')


def train_model(net, optimizer, criterion, train_loader):
    train_loss = 0.0
    net.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
    return train_loss


def validate_model(net, criterion, valid_loader):
    valid_loss = 0.0
    net.eval()
    with torch.no_grad():  # no gradients needed during validation
        for data, target in valid_loader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            loss = criterion(output, target)
            valid_loss += loss.item()*data.size(0)
    return valid_loss


def run(dataset, net_type):

    # Hyper-parameter settings (reference defaults; actual values come from cfg below)
    '''
    n_epochs = 50
    lr = 0.01
    num_workers = 4
    valid_size = 0.2
    batch_size = 256
    '''

    n_epochs = cfg.n_epochs
    lr = cfg.lr
    #beta_type = cfg.beta_type
    #num_workers = cfg.num_workers
    num_workers = 0  # hard-coded to 0 while debugging (see note below)
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size
    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)
    ckpt_dir = f'checkpoints/{dataset}/frequentist'
    ckpt_name = f'{ckpt_dir}/model_{net_type}.pt'

    os.makedirs(ckpt_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=lr)
    valid_loss_min = np.inf
    for epoch in range(1, n_epochs+1):
        train_loss = train_model(net, optimizer, criterion, train_loader)
        valid_loss = validate_model(net, criterion, valid_loader)

        train_loss = train_loss/len(train_loader.dataset)
        valid_loss = valid_loss/len(valid_loader.dataset)
            
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, train_loss, valid_loss))
        
        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min, valid_loss))
            torch.save(net.state_dict(), ckpt_name)
            valid_loss_min = valid_loss

dataset = 'MNIST'
net_type = 'alexnet'

parser = argparse.ArgumentParser(description="PyTorch Frequentist Model Training")
parser.add_argument('-f')  # dummy argument so argparse tolerates the -f flag the Jupyter/Colab kernel passes in
parser.add_argument('--net_type', default=net_type, type=str, help='model')
parser.add_argument('--dataset', default=dataset, type=str, help='dataset = [MNIST/CIFAR10/CIFAR100]')
args = parser.parse_args()

print(args.dataset)
print(args.net_type)
run(args.dataset, args.net_type)
#run(dataset, net_type)

When I use the CIFAR10 or CIFAR100 dataset, the code works. When I switch to MNIST, I get this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-34-f44957672f34> in <module>()
    115 print(args.dataset)
    116 print(args.net_type)
--> 117 run(args.dataset, args.net_type)
    118 #run(dataset, net_type)

12 frames
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)
   2542         im.palette = ImagePalette.ImagePalette()
   2543         color = im.palette.getcolor(color)
-> 2544     return im._new(core.fill(mode, size, color))
   2545 
   2546 

TypeError: function takes exactly 1 argument (3 given)

I suspect it is because MNIST is greyscale (one channel rather than RGB). I’ve also tried setting
num_workers = 0, but that still gave me the same error, and the error occurs regardless of architecture (alexnet, lenet, 3conv3fc).
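
For what it’s worth, the greyscale suspicion seems checkable at the PIL level alone. A minimal sketch that mimics the failing Image.new call from the traceback (not necessarily torchvision’s exact call path):

from PIL import Image

# MNIST samples are single-band ("L" mode) PIL images before ToTensor().
# Passing an RGB fill tuple for a single-band image triggers the same error:
Image.new('L', (32, 32), (0, 0, 0))
# TypeError: function takes exactly 1 argument (3 given)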

I’ve also tried running the repo’s main_bayesian.py, and the same thing happens for MNIST with a Bayesian CNN (it works with CIFAR10 and CIFAR100, though).

The getDataset() and getDataloader() functions are defined below so you can see the transformations applied to the data.

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

def getDataset(dataset):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        ])

    if dataset == 'CIFAR10':
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
        inputs = 3

    elif dataset == 'CIFAR100':
        trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        num_classes = 100
        inputs = 3

    elif dataset == 'MNIST':
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
        inputs = 1

    else:
        raise ValueError('dataset should be one of [CIFAR10 / CIFAR100 / MNIST]')

    return trainset, testset, inputs, num_classes

def getDataloader(trainset, testset, valid_size, batch_size, num_workers=0):
    num_train = len(trainset)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(valid_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
        sampler=train_sampler, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 
        sampler=valid_sampler, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 
        num_workers=num_workers)

    return train_loader, valid_loader, test_loader

I’ve tried following the advice in the thread “Error message following update to torchvision 0.5 ‘function takes exactly 1 argument (3 given)’”, but didn’t get anywhere.
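
As far as I can tell, the advice there amounts to passing a band-matched fill to RandomRotation, so for the MNIST branch the transform would look something like this (my attempt, not the repo’s code):

transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    # fill must match the number of image bands: a 1-tuple for greyscale MNIST
    transforms.RandomRotation(10, fill=(0,)),
    transforms.ToTensor(),
    ])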

I’m running this in Google Colab with PyTorch 1.4.0.

Thanks so much!

I assume you’ve added the fill argument to RandomRotation and are still seeing this error?
Could you try installing the nightly binaries and rerunning the code?

Yes, I did indeed try adding fill=(0,) to RandomRotation, and I got the exact same error.

Could you please specify what ‘nightly binaries’ I need to install? It’s the first time I’ve heard of these.
Thanks!

Edit: I fixed the problem by downgrading my torchvision and my pillow:

!pip install torchvision==0.4.2
!pip install pillow==6.2.0

from 0.5.0 and 7.0.0 respectively.
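
In case anyone tries the same thing in Colab: Pillow is usually already imported there, so restart the runtime after the downgrade and then confirm the versions actually changed:

import PIL
import torchvision

print(torchvision.__version__)  # expecting 0.4.2 after the downgrade
print(PIL.__version__)          # expecting 6.2.0 after the downgrade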
But downgrading isn’t the way I want to go about it…

Nightly binaries are built from the current master on a daily basis and can be found under the “Preview (Nightly)” tab in the install instructions.

I cannot reproduce this error locally with the fix, so I would recommend using the nightly build as a workaround.
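
In Colab that boils down to installing the pre-release wheels with pip; at the time of writing the command from the selector looks roughly like this (double-check the selector, since the wheel index URL depends on your CUDA version):

!pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html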

Thank you, I will try that and let you know how it goes 🙂

Just as a follow-up:
I had settled for downgrading because of its simplicity…