Mixed precision (autocast) error for a simple network

I’m trying out a simple model with autocast using the latest nightly, 1.6.0.dev20200406. The model has a single Conv2d layer followed by a GRU layer, and its forward pass, which receives float32 input, is wrapped in the autocast context as below:

```python
import torch


class Net(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, (3, 3), padding=(1, 1))
        self.gru = torch.nn.GRU(dim, 1024, batch_first=True)

    def forward(self, x):
        with torch.cuda.amp.autocast():
            out = self.conv(x)  # (N, 3, H, W); float16 under autocast
            # Flatten channels and height into the sequence dimension: (N, 3*H, W)
            out = out.reshape(-1, out.size(1) * out.size(2), out.size(3))
            out, _ = self.gru(out)  # raises CUDNN_STATUS_BAD_PARAM here
            return out


model = Net(512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=.9)
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()

x = torch.rand((32, 1, 1600, 512), dtype=torch.float32).cuda()
y = model(x)
```

This doesn’t work and throws the following trace:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-8fb506ecd9e7> in <module>
     17 
     18 x = torch.rand((32, 1, 1600, 512), dtype=torch.float32).cuda()
---> 19 y = model(x)

~/miniconda3/envs/torch-test/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    556             result = self._slow_forward(*input, **kwargs)
    557         else:
--> 558             result = self.forward(*input, **kwargs)
    559         for hook in self._forward_hooks.values():
    560             hook_result = hook(self, input, result)

<ipython-input-11-8fb506ecd9e7> in forward(self, x)
      8             out = self.conv(x)
      9             out = out.reshape(-1, out.size(1) * out.size(2), out.size(3))
---> 10             out, _ = self.gru(out)
     11             return out
     12 

~/miniconda3/envs/torch-test/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    556             result = self._slow_forward(*input, **kwargs)
    557         else:
--> 558             result = self.forward(*input, **kwargs)
    559         for hook in self._forward_hooks.values():
    560             hook_result = hook(self, input, result)

~/miniconda3/envs/torch-test/lib/python3.7/site-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    725         if batch_sizes is None:
    726             result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 727                              self.dropout, self.training, self.bidirectional, self.batch_first)
    728         else:
    729             result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,

RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
```

Under autocast the Conv2d output is float16, while the GRU here appears to expect float32 input, so explicitly casting the Conv2d output to float32 resolves the error. I’m migrating from Apex, and I’d expect autocast, like Apex, to make this manual cast unnecessary. Am I missing something here, or is this a limitation of the current version? @mcarilli

PyTorch version = 1.6.0.dev20200406+cu101
cuDNN enabled
CUDA = 10.1
OS = Ubuntu 19.10
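The explicit float32 cast described in the question can be kept local to the GRU by nesting an autocast-disabled region and calling `.float()` on its input, which is the pattern the autocast documentation shows for subregions that must run in float32. This is a minimal sketch, not tested against the 1.6.0.dev20200406 nightly: the convolution still runs under autocast, while the GRU sees float32 input and its own float32 weights. The usage lines fall back to CPU (where `torch.cuda.amp.autocast` disables itself with a warning) so the shapes can be checked without a GPU; newer releases also expose the same context manager as `torch.autocast(device_type="cuda")`.

```python
import torch


class Net(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, (3, 3), padding=(1, 1))
        self.gru = torch.nn.GRU(dim, 1024, batch_first=True)

    def forward(self, x):
        with torch.cuda.amp.autocast():
            # Conv2d still benefits from float16 under autocast on CUDA.
            out = self.conv(x)
            out = out.reshape(-1, out.size(1) * out.size(2), out.size(3))
            # Workaround sketch: leave autocast for the GRU and feed it
            # float32 input, so it runs entirely in float32.
            with torch.cuda.amp.autocast(enabled=False):
                out, _ = self.gru(out.float())
        return out


device = "cuda" if torch.cuda.is_available() else "cpu"
model = Net(16).to(device)
y = model(torch.rand((4, 1, 10, 16), device=device))
print(y.shape, y.dtype)  # torch.Size([4, 30, 1024]) torch.float32
```

The trade-off is that the GRU gets no float16 speedup; only the layers outside the disabled region do.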

Glad you’re trying the native API, and thanks for the minimal repro. I think this indicates my wrappers aren’t reflattening weights in the way cuDNN RNNs require. cuDNN RNN support is a high priority; I’ll look at it soon. In the meantime, the cell-based RNN API (for example `torch.nn.GRUCell` or `torch.nn.LSTMCell`) should work.
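A minimal sketch of the cell-based workaround, assuming it means replacing `nn.GRU` with `nn.GRUCell` and stepping over the sequence in Python; `CellNet` and its `hidden` argument are hypothetical names. It keeps the same reshape as the original `Net` and stacks the per-step hidden states so the output has the `(N, seq_len, hidden)` layout that `nn.GRU(batch_first=True)` returns. A Python loop over the repro's 4800 time steps (3 × 1600) will be far slower than the fused cuDNN kernel, so treat it as a stopgap.

```python
import torch


class CellNet(torch.nn.Module):
    """Hypothetical variant of Net that swaps nn.GRU for nn.GRUCell."""

    def __init__(self, dim, hidden=1024):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, (3, 3), padding=(1, 1))
        self.cell = torch.nn.GRUCell(dim, hidden)

    def forward(self, x):
        with torch.cuda.amp.autocast():
            out = self.conv(x)
            # (N, C*H, W): sequence length C*H, feature size W, as in Net.
            out = out.reshape(-1, out.size(1) * out.size(2), out.size(3))
            h = None  # GRUCell starts from a zero hidden state when h is None
            steps = []
            for t in range(out.size(1)):
                h = self.cell(out[:, t], h)
                steps.append(h)
            # Stack per-step hidden states into (N, seq_len, hidden).
            return torch.stack(steps, dim=1)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = CellNet(16).to(device)
y = model(torch.rand((4, 1, 10, 16), device=device))
print(y.shape)  # torch.Size([4, 30, 1024])
```

Because `nn.GRUCell` and a single-layer `nn.GRU` share the same gate equations and parameter shapes, copying the cell's `weight_ih`, `weight_hh`, `bias_ih`, and `bias_hh` into `weight_ih_l0` and friends should give matching float32 outputs, which is a convenient way to check the loop.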

@mcarilli I can confirm this also happens with 1.6.0.dev20200430. It affects LSTM as well (my case), and manually permuting the inputs and switching away from batch_first=True doesn’t help (in the past some things were broken with batch_first=True).
Also, I’m not sure how best to reach you about a specific use case with variable batch size. From what I’ve read about AMP, squeezing out maximum performance requires the dimensions of layers like Linear or Conv2d, as well as the batch size, to be divisible by 8. Is there any workaround to get decent performance with a variable batch size? I can ask this as a completely new question if that is preferred.
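One workaround sometimes used for variable batch sizes (not one this thread confirms) is to zero-pad each batch up to the next multiple of 8 and drop the padded rows from the output. A minimal sketch, where `pad_batch_to_multiple` is a hypothetical helper and the Linear layer sizes are arbitrary:

```python
import torch


def pad_batch_to_multiple(x, multiple=8):
    """Hypothetical helper: zero-pad dim 0 of x up to the next multiple.

    Returns the padded tensor and the original batch size.
    """
    n = x.size(0)
    padded_n = -(-n // multiple) * multiple  # ceil(n / multiple) * multiple
    if padded_n == n:
        return x, n
    pad = x.new_zeros((padded_n - n,) + tuple(x.shape[1:]))
    return torch.cat([x, pad], dim=0), n


x = torch.rand(13, 512)
xp, n = pad_batch_to_multiple(x)
layer = torch.nn.Linear(512, 1024)
y = layer(xp)[:n]  # discard the rows that came from padding
print(xp.shape, y.shape)  # torch.Size([16, 512]) torch.Size([13, 1024])
```

Padding spends compute on dummy rows, and it is only safe for layers that treat samples independently; BatchNorm statistics, for example, would include the zero rows. Whether it beats running the unpadded batch depends on the model and GPU, so it is worth profiling both.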

A fix is underway: https://github.com/pytorch/pytorch/issues/36428#issuecomment-665748638
