ViT Poor Accuracy on ImageNet

Hello,

I’ve been trying to train a ViT model on ImageNet, but no matter how long I train it, it only reaches about 6% accuracy. My code is below:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from vit_pytorch import ViT

def get_params_groups(model):
    regularized = []
    not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith(".bias") or len(param.shape) == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]

def train():
    batch_size = 256
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # evaluation should use a deterministic resize/crop, not the random training augmentations
    test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
    train_dataset = torchvision.datasets.ImageFolder(root='/datasets/train', transform=train_transform)
    test_dataset = torchvision.datasets.ImageFolder(root='/datasets/val', transform=test_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    num_batches = len(train_loader)

    v = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 384,
        depth = 12,
        heads = 6,
        mlp_dim = 384*4
    )

    device = torch.device('cuda')
    v = v.to(device)
    optimizer = optim.AdamW(get_params_groups(v), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(30):
        v.train()  # re-enable training mode; eval() below would otherwise persist into the next epoch
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs = inputs.to(device)  # Move the inputs to GPU
            labels = labels.to(device)  # Move the labels to GPU
            optimizer.zero_grad()
            outputs = v(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 500 == 0:
                print('[Epoch %d, Batch %5d / %5d] loss: %.3f' % (epoch + 1, i + 1, num_batches, running_loss / (i + 1)))
        print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))

        correct = 0
        total = 0
        v.eval()
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                images = images.to(device)  # Move the images to GPU
                labels = labels.to(device)  # Move the labels to GPU
                outputs = v(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the model on the test set: %.2f %%' % (100 * correct / total))

    print('Finished training.')

    correct = 0
    total = 0
    v.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images = images.to(device)  # Move the images to GPU
            labels = labels.to(device)  # Move the labels to GPU
            outputs = v(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Final accuracy of the model on the test set: %.2f %%' % (100 * correct / total))

if __name__ == "__main__":
    train()

I’m using Lucidrains’ ViT implementation, so I don’t think the issue is the architecture. I’d appreciate any insight into what the issue could be!

I’m unsure which exact architecture you are using, but you could check the torchvision implementation here, which also links to the training recipes used to create the pretrained ImageNet weights.
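For reference, here is a minimal sketch of loading torchvision's pretrained ViT-B/16 together with its bundled preprocessing (this assumes torchvision >= 0.13, where the multi-weight API is available):

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Load ViT-B/16 with its ImageNet-1k weights
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)
model.eval()

# The weights object carries the exact eval preprocessing (resize, crop, normalize)
preprocess = weights.transforms()

x = torch.randn(1, 3, 224, 224)  # stand-in for a batch of already-preprocessed images
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])

Comparing the linked recipe (optimizer, learning-rate schedule, augmentation, number of epochs) against your script should help you spot what's missing.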