File "/workspace/src/LPR/architecture/networks.py", line 107, in forward
out = self.UpBlock2(x)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/padding.py", line 170, in forward
return F.pad(input, self.padding, 'reflect')
File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3570, in _pad
return torch._C._nn.reflection_pad2d(input, pad)
RuntimeError: input tensor must fit into 32-bit index math
I’m using PyTorch Lightning and building on the pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel Docker image.
I have no idea how to even interpret the error message, and Google turns up practically nothing.
My first intuition was to turn off 16-bit precision, but the problem remains. Any pointers are welcome.
What are your input size, stride, and padding?
The error says that your input (or the padded output) doesn’t fit into 32 bit index(!) math, i.e. it would require memory offsets >= 2^31. If you feed large batches or many channels, you could work around it by splitting the input into smaller tensors and passing them through separately, as sketched below. The alternative would be to fix PyTorch to use 64 bit indexing in the reflection pad kernels (or copy those kernels into an extension module and make the change there).
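A minimal sketch of that chunking workaround, assuming the failing layer is an nn.ReflectionPad2d (the layer, padding size, and chunk count here are placeholders, not taken from the original code):

import torch
import torch.nn as nn

pad = nn.ReflectionPad2d(1)

def pad_in_chunks(x: torch.Tensor, num_chunks: int = 4) -> torch.Tensor:
    # Split along the batch dimension so each piece needs far fewer
    # memory offsets, pad each piece, then reassemble the full batch.
    parts = [pad(part) for part in torch.chunk(x, num_chunks, dim=0)]
    return torch.cat(parts, dim=0)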
The background is that 64 bit index math can be significantly slower on GPUs than 32 bit, so some code is implemented only with 32 bit index math, other code checks the sizes and falls back to 64 bit only when needed, and a third part always uses 64 bit because index math speed is not considered an important enough factor to warrant avoiding it. One would need to benchmark the difference to see which is the case here.
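If you want to check whether a given tensor would cross that limit, a rough test (assuming the restriction is on the element count of the padded output) is:

import torch

def fits_32bit_index(t: torch.Tensor) -> bool:
    # Kernels limited to 32-bit index math can address at most
    # 2**31 - 1 elements; anything at or above 2**31 hits the error.
    return t.numel() < 2**31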
Hey, I just stumbled across this issue because I got the same error, also using PyTorch Lightning and PyTorch 1.6.0. Reducing the batch size helped in my case.
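If the smaller batch hurts your effective batch size, one option (a sketch, not from this thread) is to pair it with Lightning's gradient accumulation:

from pytorch_lightning import Trainer

# Halve the DataLoader batch size, then accumulate gradients over
# two batches so the effective batch size stays the same.
trainer = Trainer(accumulate_grad_batches=2)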