Torchvision model vit_b_16 fails to train with AMP

I am trying to use the vit_b_16 torchvision model with AMP (torch.cuda.amp / torch.autocast):
https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html

I am encountering the error:

Traceback (most recent call last):
  File "/home/phil/Code/ReLish/benchmark/train_cls.py", line 95, in main
    train_model(C, train_loader, valid_loader, model, output_layer, criterion, optimizer, scheduler)
  File "/home/phil/Code/ReLish/benchmark/train_cls.py", line 455, in train_model
    output = model(data)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1113, in forward
    return torch._native_multi_head_attention(
RuntimeError: expected scalar type Half but found Float

The error occurs in the first epoch, after the training phase completes, right when validation should start. Without AMP, no error occurs, and every other model I’ve tried (including Swin transformers) works fine.

I have found this, but am not sure how it would apply:

Maybe it’s something similar to this?

This is the code I’m using:

For reference I call it as (Python 3.9, PyTorch 1.12.1, CUDA 11.6, cuDNN 8.3.2):

./train_cls.py --act_func=original --batch_size=32 --dataset=Imagenette --epochs=120 --model=vit_b_16

Adding --no_amp to the command line turns AMP off.
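For reference, the AMP path follows the standard autocast/GradScaler pattern; a simplified sketch of the per-epoch training step (the real train_cls.py does more, and run_epoch is a made-up name here):

```python
import torch

def run_epoch(model, loader, criterion, optimizer, device="cuda", use_amp=True):
    # Simplified sketch (assumption): the real script also drives the
    # scheduler, metrics and the validation pass; the AMP usage is standard.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp and device == "cuda")
    model.train()
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with torch.autocast(device_type=device, enabled=use_amp):
            output = model(data)
            loss = criterion(output, target)
        scaler.scale(loss).backward()  # no-op scaling when AMP is disabled
        scaler.step(optimizer)
        scaler.update()
```

With --no_amp, use_amp is False and both the autocast region and the GradScaler become no-ops, so the loop degenerates to a plain FP32 training step.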

Can anyone help?


The error seems to be raised by the MHA layer. Could you check if the latest nightly binary still raises the error, please?

No, the error is gone under pytorch-nightly/linux-64::pytorch-1.14.0.dev20221011-py3.9_cuda11.6_cudnn8.3.2_0. What is/was the problem, and can I do anything to overcome the issue now, besides moving to a nightly PyTorch release? (The nightly broke other, unrelated parts of my code that use torch.jit.script; I had to comment them out, otherwise it would crash before I had the chance to train.)
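One possible stopgap for the torch.jit.script breakage, instead of commenting the calls out, might be a fallback wrapper like this (maybe_script is a hypothetical helper, untested against the actual nightly failure):

```python
import torch

def maybe_script(fn):
    # Hypothetical fallback: if scripting fails on the current build
    # (e.g. a nightly regression), keep the eager-mode function instead.
    try:
        return torch.jit.script(fn)
    except Exception:
        return fn

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

double = maybe_script(double)  # scripted where possible, eager otherwise
```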

I don’t know what the fix was, but you could check the commits, e.g. via git blame, to see which issues were fixed and whether any of them matches your error.