ReflectionPad2D: "input tensor must fit into 32-bit index math"

Hi everybody, I just stumbled upon that error message and have absolutely no idea what to make of it.

I’m trying to use the generator part of znxlwm/UGATIT-pytorch on GitHub, the official PyTorch implementation of U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation (a GAN). It works fine on my machine, but when I started it on our cluster (A100 GPU) I received this error:

  File "/workspace/src/LPR/architecture/", line 107, in forward
    out = self.UpBlock2(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 117, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 170, in forward
    return F.pad(input, self.padding, 'reflect')
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/", line 3570, in _pad
    return torch._C._nn.reflection_pad2d(input, pad)
RuntimeError: input tensor must fit into 32-bit index math

I’m using PyTorch Lightning and am building upon the pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel docker.

I have no idea how to even interpret the error message, and Google turns up practically nothing.

My first intuition was to turn 16-bit precision off, but the problem remains. Any pointers are welcome.

What are your input size, stride, and padding?
The error says that your input (or padded input) doesn’t fit into 32-bit index(!) math, i.e. addressing it needs memory offsets of 2^31 or more. If you feed large batches or a large number of channels, you could work around it by splitting the input into smaller tensors (for example along the batch or channel dimension), padding each one separately, and concatenating the results. The alternative would be to fix PyTorch to use 64-bit indexing in the reflection pad kernels (or just copy their kernels and do this in an extension module).
The background is that 64-bit index math can be significantly slower on GPUs than 32-bit, so some kernels are only implemented with 32-bit index math, while others check the size and fall back to 64-bit only when needed. A third group always uses 64-bit because index math speed is not considered important enough there to warrant avoiding it. One would need to benchmark the difference to see which is the case here.
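To make the splitting workaround concrete, here is a minimal sketch that pads a 4-D tensor in batch chunks and concatenates the results. The helper name `reflection_pad2d_chunked` and the `max_elems` threshold are my own assumptions for illustration; I have not checked exactly what quantity the PyTorch kernel compares against the 32-bit limit, so treat the threshold as a conservative guess rather than the real condition.

```python
import torch
import torch.nn.functional as F

# Assumed threshold: largest value of a signed 32-bit index.
INT32_MAX = 2**31 - 1


def reflection_pad2d_chunked(x, pad, max_elems=INT32_MAX):
    """Reflection-pad an (N, C, H, W) tensor in chunks along the batch dimension.

    pad is (left, right, top, bottom), the order F.pad expects for the
    last two dimensions. Each chunk is sized so that its padded output
    holds at most max_elems elements.
    """
    left, right, top, bottom = pad
    n, c, h, w = x.shape
    out_per_sample = c * (h + top + bottom) * (w + left + right)
    if out_per_sample > max_elems:
        # Batch chunking cannot help here; one would have to split channels too.
        raise ValueError("a single sample already exceeds max_elems")
    chunk = max(1, max_elems // out_per_sample)
    parts = [
        F.pad(x[i:i + chunk], pad, mode="reflect")
        for i in range(0, n, chunk)
    ]
    return torch.cat(parts, dim=0)
```

Inside a model this would replace the direct call to the `ReflectionPad2d` layer. Reflection padding treats every sample independently, so the chunked result should match the unchunked one; the cost is extra kernel launches and a concatenation.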

Best regards


Investigating this is a little troublesome, since I actually just went on to use another architecture now.

Replicating it on my machine (which runs on CPU, by the way), the inputs to the three padding layers are (input shape, then padding):

[512, 64, 46, 140] (1,1,1,1)
[512, 32, 92, 280] (1,1,1,1)
[512, 16, 92, 280] (3,3,3,3)

But like I said, it works on my machine. So I can’t determine from here which of the three were the culprit.
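For reference, the element counts for these three shapes can be checked against 2^31 with plain Python. This is only a sketch: it assumes the padding tuples are (left, right, top, bottom) and that the relevant quantity is the number of elements in the input or padded output, which may not be exactly what the kernel checks.

```python
from math import prod

LIMIT = 2**31  # offsets at or above this do not fit into signed 32-bit indexing

# (input shape, padding) pairs as reported above
cases = [
    ((512, 64, 46, 140), (1, 1, 1, 1)),
    ((512, 32, 92, 280), (1, 1, 1, 1)),
    ((512, 16, 92, 280), (3, 3, 3, 3)),
]


def padded_shape(shape, pad):
    """Shape after 2-D padding, assuming pad = (left, right, top, bottom)."""
    n, c, h, w = shape
    left, right, top, bottom = pad
    return (n, c, h + top + bottom, w + left + right)


for shape, pad in cases:
    n_in = prod(shape)
    n_out = prod(padded_shape(shape, pad))
    print(shape, pad, "input:", n_in, "padded:", n_out, "fits:", n_out < LIMIT)
```

By this count all three stay well below 2^31 (the largest padded tensor has 434,307,072 elements), so if the assumptions hold, these exact shapes would not be expected to trip the check. The sizes on the cluster run may have differed from the ones replicated here.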

Nevertheless, your explanation sounds like it’s more an issue of exploding gradients/values, so input sizes won’t tell you anything anyway, will they?

Hey, I just stumbled across this issue because I got the same error, also using pytorch-lightning and pytorch 1.6.0. Reducing the batch size helped in my case.