I am trying to use the vit_b_16 torchvision model with AMP (torch.cuda.amp / torch.autocast):
https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html
I am encountering the error:
Traceback (most recent call last):
File "/home/phil/Code/ReLish/benchmark/train_cls.py", line 95, in main
train_model(C, train_loader, valid_loader, model, output_layer, criterion, optimizer, scheduler)
File "/home/phil/Code/ReLish/benchmark/train_cls.py", line 455, in train_model
output = model(data)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
x = self.encoder(x)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
return self.ln(self.layers(self.dropout(input)))
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/phil/anaconda3/envs/relish/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1113, in forward
return torch._native_multi_head_attention(
RuntimeError: expected scalar type Half but found Float
The error occurs in the first epoch, AFTER the training phase has completed, right when validation should start. Without AMP, no error occurs. With every other model I've tried (including Swin transformers), there is no issue.
I have found this, but am not sure how it would apply:
Maybe it’s something similar to this?
This is the code I’m using:
For reference, I call it as follows (Python 3.9, PyTorch 1.12.1, CUDA 11.6, cuDNN 8.3.2):
./train_cls.py --act_func=original --batch_size=32 --dataset=Imagenette --epochs=120 --model=vit_b_16
Adding --no_amp to the command line turns AMP off.
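For context, a switch like --no_amp is commonly wired to autocast and GradScaler roughly as below. This is a hypothetical sketch, not the actual train_cls.py: the parser options, amp_context, train_step and valid_step are illustrative names, and it assumes validation runs under model.eval() and torch.no_grad().

```python
import argparse
import contextlib

import torch
import torch.nn as nn


def parse_args(argv=None):
    # Hypothetical subset of train_cls.py's options; the real parser is not shown.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="vit_b_16")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=120)
    parser.add_argument("--no_amp", action="store_true", help="disable automatic mixed precision")
    return parser.parse_args(argv)


def amp_context(use_amp, device_type="cuda", dtype=torch.float16):
    # Autocast region when AMP is on; a no-op context manager otherwise.
    if use_amp:
        return torch.autocast(device_type=device_type, dtype=dtype)
    return contextlib.nullcontext()


def train_step(model, data, target, criterion, optimizer, scaler, use_amp):
    model.train()
    optimizer.zero_grad()
    with amp_context(use_amp, data.device.type):
        loss = criterion(model(data), target)
    # With GradScaler(enabled=False), scale/step/update act as plain pass-throughs.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


@torch.no_grad()
def valid_step(model, data, use_amp):
    model.eval()  # validation: eval mode plus no_grad (assumed, as in the sketch's lead-in)
    with amp_context(use_amp, data.device.type):
        return model(data)
```

In this layout the only differences between the two phases are train()/eval() and gradient tracking, which matches the observation that the error appears only once validation begins.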
Can anyone help?