PyTorch AMP CUDA error with Transformer

Hi, I need some help with this error.

I’m encountering a CUDA error with a Transformer model. Below is a minimal example that reproduces it. The same code works fine with AMP disabled, and AMP also works for typical ConvNets, so I wonder whether something in the PyTorch Transformer is incompatible with AMP.

NVIDIA-SMI 418.40.04, Driver Version: 418.40.04, CUDA Version: 10.1, GPU: Tesla V100-SXM2

Python 3.7.6 (default, Jan 8 2020, 19:59:22)
[GCC 7.3.0] :: Anaconda, Inc. on linux

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class SimpleVT(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        # d_model=40, nhead=2, dim_feedforward=20, dropout=0.5
        self.enc_layers = TransformerEncoderLayer(40, 2, 20, 0.5)
        # TransformerEncoder stacks 2 deep copies of enc_layers
        self.encoder = TransformerEncoder(self.enc_layers, 2)
        self.decoder = nn.Linear(40, 2)

    def forward(self, x):
        x = self.enc_layers(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = SimpleVT().cuda()

# By default (batch_first=False) this is read as (seq_len=64, batch=102, d_model=40)
x = torch.rand([64, 102, 40]).cuda()
y_hat = torch.rand([64, 102, 2]).cuda()

with torch.cuda.amp.autocast(enabled=True):
    y = model(x)
    loss = nn.MSELoss()(y, y_hat)  # (input, target)


This code fails with the following error.
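To check whether the failure is really specific to AMP, as suspected above, one option is to run the same forward pass once with autocast disabled and once enabled, with `CUDA_LAUNCH_BLOCKING=1` so that a CUDA error is reported at the kernel that caused it rather than at a later synchronization point. The helper name `compare_amp` is hypothetical; this is a debugging sketch, not a fix.

```python
import os

# Report CUDA errors synchronously, at the kernel that actually failed.
# Must be set before the first CUDA call (safest: before importing torch).
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")


def compare_amp(model, x):
    """Run one forward pass with autocast off and on (hypothetical helper).

    Returns a dict mapping "fp32"/"amp" to either (shape, dtype) of the
    output or the error message, so an AMP-only failure shows up as an
    error under the "amp" key only.
    """
    import torch  # imported lazily so the helper can be defined anywhere

    results = {}
    # fp32 runs first: some CUDA errors leave the context unusable afterwards
    for name, enabled in (("fp32", False), ("amp", True)):
        try:
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=enabled):
                y = model(x)
            torch.cuda.synchronize()  # surface any pending asynchronous error
            results[name] = (tuple(y.shape), str(y.dtype))
        except RuntimeError as err:
            results[name] = f"error: {err}"
    return results
```

With the repro above it would be called as `compare_amp(model, x)`. On a setup where the model works, both entries should report the shape (64, 102, 2); if only the "amp" entry holds an error, that points at the half-precision code path rather than the model definition.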

I cannot reproduce it on a V100-SXM2 32GB using PyTorch 1.6.0 with the CUDA 10.1 binaries; I get a valid output of shape torch.Size([64, 102, 2]).

EDIT: the only difference seems to be the driver version, so could you update it?

Can you tell me which version worked for you?

I’m using 450.51.06.
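Since the thread narrows the difference down to the driver, a small stdlib-only sketch for comparing the installed driver with the two versions mentioned here could look like this. It assumes `nvidia-smi` is on `PATH` and uses its `--query-gpu=driver_version` query. The constant and function names are hypothetical, and 450.51.06 is only the version that worked for the replier, not a documented minimum requirement.

```python
import shutil
import subprocess

# Driver versions mentioned in this thread (hypothetical constant names).
FAILING_DRIVER = "418.40.04"  # original poster, AMP repro fails
WORKING_DRIVER = "450.51.06"  # replier, repro runs fine


def parse_driver_version(version: str) -> tuple:
    """Turn a dotted driver version such as '450.51.06' into (450, 51, 6)."""
    return tuple(int(part) for part in version.strip().split("."))


def installed_driver_version():
    """Return the driver version reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    proc = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    lines = proc.stdout.strip().splitlines()
    if proc.returncode != 0 or not lines:
        return None
    return lines[0].strip()  # one line per GPU; all GPUs share one driver


def at_least_known_good(version: str) -> bool:
    """True if `version` is not older than the driver reported to work here."""
    return parse_driver_version(version) >= parse_driver_version(WORKING_DRIVER)
```

Tuples compare element by element, so 418.40.04 sorts below 450.51.06 as expected. For the PyTorch side, `python -m torch.utils.collect_env` prints the PyTorch, CUDA and cuDNN versions, which makes it easier to confirm that the driver really is the only difference between two machines.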