NaN loss after a few epochs of fine-tuning ViT-B/16 on smaller datasets

I am fine-tuning a pretrained ViT on CIFAR-100 (images resized to 224). Training starts out well, with decreasing loss and decent accuracy, but then the loss suddenly goes to NaN and the accuracy drops to random-guess level.

I used the Adam optimizer with a learning rate of 0.0001. With a learning rate of 0.001 the same issue occurs, just a few epochs earlier.

This is my code:

import torch
import torch.nn as nn
import torchvision

# device, lr, weight_decay, total_epochs and trainloader are defined elsewhere
model = torchvision.models.vit_b_16(weights='IMAGENET1K_V1')  # ViT-B/16 pretrained on ImageNet-1k
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in range(total_epochs):
    model.train()
    train_loss = 0
    correct = 0
    for i, (x, y) in enumerate(trainloader):
        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device)
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # track training accuracy
        _, predicted_train = outputs.max(1)
        correct += predicted_train.eq(y).sum().item()

With a learning rate of 0.0001, NaN values appear after roughly 25 epochs.
With a learning rate of 0.001, NaN values appear after roughly 5 epochs.

Without any additional information, my guess is that the L2 regularization applied through weight decay is leading to a division-by-zero problem that ends up propagating NaN values. Have you tried changing its value?
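For example, you could set it to zero for a run to rule it out, or switch to AdamW, which decouples weight decay from the Adam update. The values below are only placeholders, not a recommendation:

# Option 1: keep Adam but turn off the L2 penalty to see if the NaNs stop
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0)

# Option 2: AdamW applies decoupled weight decay instead of L2 regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)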

Are you saying the current weight decay value could be too high, pushing the weights toward zero and causing the division-by-zero error?

Exactly what I meant.

Even with gradient clipping, the loss goes to NaN after a few epochs.
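For reference, I'm clipping the global gradient norm between backward() and step(), roughly like this (max_norm=1.0 is just an example value):

loss.backward()
# clip the global gradient norm before the optimizer update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()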

@pytorchuser101 Can you check whether you are normalizing your images with the ImageNet mean and variance? This issue generally does not occur when CIFAR-100 inputs are normalized properly.
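You could also try to narrow down exactly where the first NaN appears, e.g. by checking the loss every step and enabling anomaly detection. A rough debugging sketch (it slows training down, so only use it to diagnose):

torch.autograd.set_detect_anomaly(True)  # enable once, before the training loop

# inside the training loop, right after computing the loss:
if not torch.isfinite(loss):
    print(f"Non-finite loss at epoch {epoch}, step {i}")
    break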

Yes, I am normalizing the train and test sets with the ImageNet mean and standard deviation:
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)

I also tried (0.5, 0.5, 0.5) for both the mean and the standard deviation; it gives the same issue.

I am resizing the CIFAR images to 224x224.
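The preprocessing looks roughly like this (the batch size, num_workers and data root are just example values, not my exact settings):

import torchvision.transforms as transforms

# resize CIFAR-100 from 32x32 to 224x224 and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)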