Hi, I need some help with this error.
I'm encountering a CUDA error when training a Transformer model with automatic mixed precision (AMP). Below is a minimal example that reproduces it. The same code runs fine with AMP disabled, and AMP also works for typical ConvNets, so I wonder whether something in PyTorch's Transformer modules is incompatible with AMP. Can someone help with this error?
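One way to test the "AMP-incompatible Transformer" hypothesis is to run a single forward/backward step through each module with autocast on and off and see which combination fails. The sketch below does this for the `TransformerEncoderLayer(40, 2, 20, 0.5)` configuration from the repro and for a small `Conv2d` used as a stand-in ConvNet. The helper names `try_step` and `compare`, the squared-output loss, and the conv shapes are hypothetical choices for illustration. CUDA errors are often reported asynchronously and can leave the CUDA context unusable, so setting `CUDA_LAUNCH_BLOCKING=1` and running each case in a fresh process gives more reliable attribution. Without a GPU the script falls back to CPU, where CUDA autocast does not apply, so the comparison is only meaningful on the GPU.

```python
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer


def try_step(module, x, use_amp):
    """Run one forward/backward step; return None on success, else the error text."""
    module.zero_grad()
    try:
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = module(x)
            # Hypothetical loss: mean of squared outputs, computed in float32
            loss = out.float().pow(2).mean()
        # Backward runs outside the autocast region
        loss.backward()
        return None
    except RuntimeError as err:
        return str(err)


def compare(device):
    """Try each candidate module with AMP off and on; map (name, use_amp) -> error or None."""
    torch.manual_seed(0)
    candidates = {
        # Same layer configuration as the repro: d_model=40, nhead=2, ff=20, dropout=0.5
        "transformer_layer": (TransformerEncoderLayer(40, 2, 20, 0.5), torch.rand(64, 102, 40)),
        # Hypothetical small conv layer standing in for a "typical ConvNet"
        "conv": (nn.Conv2d(3, 8, 3), torch.rand(4, 3, 32, 32)),
    }
    results = {}
    for name, (module, x) in candidates.items():
        module, x = module.to(device), x.to(device)
        for use_amp in (False, True):
            results[(name, use_amp)] = try_step(module, x, use_amp)
    return results


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for key, err in compare(device).items():
        print(key, "ok" if err is None else err)
```

If only the `("transformer_layer", True)` case fails on the GPU, that narrows the problem to the attention/feed-forward path under autocast rather than to the surrounding training code.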
Environment: Tesla V100-SXM2, NVIDIA-SMI 418.40.04, Driver Version 418.40.04, CUDA Version 10.1; Python 3.7.6 (default, Jan 8 2020, 19:59:22) [GCC 7.3.0] :: Anaconda, Inc. on linux.
PyTorch version:
```
>>> import torch
>>> print(torch.__version__)
1.6.0+cu101
```
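For CUDA errors, the CUDA and cuDNN versions PyTorch was built against can matter as much as the driver; note that `torch.version.cuda` reports the build's CUDA version, which can differ from the "CUDA Version" shown by `nvidia-smi`. PyTorch ships `python -m torch.utils.collect_env` for a full report. The small sketch below (the function name `environment_report` is hypothetical) collects just the fields most relevant to this kind of issue.

```python
import torch


def environment_report():
    """Collect version details that usually matter for CUDA/AMP bug reports."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,          # None for CPU-only builds
        "cudnn": torch.backends.cudnn.version(),  # None if cuDNN is unavailable
        "gpu": torch.cuda.get_device_name(0) if cuda_ok else None,
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```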
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class SimpleVT(nn.Module):
    def __init__(self):
        super().__init__()
        # d_model=40, nhead=2, dim_feedforward=20, dropout=0.5
        self.enc_layers = TransformerEncoderLayer(40, 2, 20, 0.5)
        # TransformerEncoder deep-copies enc_layers to build 2 independent layers
        self.encoder = TransformerEncoder(self.enc_layers, 2)
        self.decoder = nn.Linear(40, 2)

    def forward(self, x):
        x = self.enc_layers(x)  # the standalone layer is also applied once
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = SimpleVT().cuda()
# Default layout is (seq_len, batch, d_model)
x = torch.rand([64, 102, 40]).cuda()
y_hat = torch.rand([64, 102, 2]).cuda()

with torch.cuda.amp.autocast(enabled=True):
    y = model(x)
    loss = nn.MSELoss()(y_hat, y)
# Backward is run outside the autocast region, as the AMP docs recommend
loss.backward()
print(y.size())
```
This code fails with the following error.