RuntimeError: Expected canUse32BitIndexMath(input)

Hi,
I’m trying to use this repo on an NVIDIA RTX A6000 (48 GB), and I get the following error:

RuntimeError: Expected canUse32BitIndexMath(input) && canUse32BitIndexMath(output) to be true, but got false.
(Could this error message be improved? If so, please report an enhancement request to PyTorch.)

After some investigation, I understand that this issue occurs because large images can result in an output tensor exceeding 2^31 elements, which is not supported by cuDNN versions earlier than 9.3 (related PyTorch issue and PR). Since I am performing inference, I cannot split the batch to work around this limitation.
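Before running a given conv2d call, you can work out by hand whether it crosses this limit. Below is a minimal sketch in plain Python. It assumes square kernels and integer stride/padding/dilation, and the helper names are hypothetical. The bound used (numel strictly below 2**31 - 1) is my reading of canUse32BitIndexMath, so treat it as an assumption:

```python
from math import prod

# Assumed bound: canUse32BitIndexMath appears to require numel < INT32_MAX.
INT32_MAX = 2**31 - 1


def conv2d_out_hw(h, w, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution with square integer hyperparameters."""
    def out(size):
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return out(h), out(w)


def fits_32bit_indexing(shape):
    """True if a contiguous tensor of this shape stays below INT32_MAX elements."""
    return prod(shape) < INT32_MAX


def conv2d_needs_64bit_indexing(n, c_in, h, w, c_out, kernel, stride=1, padding=0, dilation=1):
    """True if the conv input or output is too large for 32-bit index math."""
    h_out, w_out = conv2d_out_hw(h, w, kernel, stride, padding, dilation)
    return not (fits_32bit_indexing((n, c_in, h, w))
                and fits_32bit_indexing((n, c_out, h_out, w_out)))


if __name__ == "__main__":
    # 1 x 2 x 32800 x 32800 = 2,151,680,000 elements, just above the limit
    print(conv2d_needs_64bit_indexing(1, 2, 32800, 32800, 2, kernel=3, padding=1))  # True
    # A 1024 x 1024 feature map with 64 channels is far below it
    print(conv2d_needs_64bit_indexing(1, 64, 1024, 1024, 64, kernel=3, padding=1))  # False
```

The input and output are checked separately because the error message tests both.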

The problem is that, as of today, I’m using PyTorch 2.7, which ships with cuDNN 9.5 (nvidia-cudnn-cu12==9.5.1.17) and should therefore not have this issue.
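The cuDNN version loaded at runtime can differ from the one pinned in the pip requirements, so it may be worth confirming. torch.backends.cudnn.version() returns an integer. The decoding below assumes that cuDNN 9 encodes this integer as major*10000 + minor*100 + patch, and cuDNN 8 as major*1000 + minor*100 + patch. This is a sketch, and the helper name is hypothetical:

```python
def decode_cudnn_version(v):
    """Turn the integer from torch.backends.cudnn.version() into (major, minor, patch).

    Assumption: cuDNN >= 9 uses major*10000 + minor*100 + patch,
    while cuDNN 8 uses major*1000 + minor*100 + patch.
    """
    if v >= 90000:
        return v // 10000, (v // 100) % 100, v % 100
    return v // 1000, (v // 100) % 10, v % 100


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    # Only query cuDNN when PyTorch is installed and built with cuDNN support
    if torch is not None and torch.backends.cudnn.is_available():
        version = decode_cudnn_version(torch.backends.cudnn.version())
        print("PyTorch", torch.__version__, "CUDA", torch.version.cuda, "cuDNN", version)
        print("cuDNN >= 9.3:", version >= (9, 3, 0))
```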

Any hints?

Were you able to create a stack trace pointing to the failing call? The linked source file does not use canUse32BitIndexMath directly, so I’m unsure whether your conv or another kernel is failing.

Here you are:

  Traceback (most recent call last):
    File "/home/vrai/Reti-Diff/Reti-Diff/test.py", line 60, in <module>
      test_pipeline(root_path)
    File "/home/vrai/Reti-Diff/Reti-Diff/test.py", line 53, in test_pipeline
      model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/basicsr/models/base_model.py", line 48, in validation
      self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
    File "/home/vrai/Reti-Diff/Reti-Diff/models/S2_Interface_model.py", line 310, in nondist_validation
      self.test()
    File "/home/vrai/Reti-Diff/Reti-Diff/models/S2_Interface_model.py", line 409, in test
      self.output = self.net_g(lq, retinex_lq)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/vrai/Reti-Diff/Reti-Diff/archs/S2_interface_arch.py", line 562, in forward
      sr = self.G(img, IPRS2_rex, IPRS2_img)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/vrai/Reti-Diff/Reti-Diff/archs/S2_interface_arch.py", line 352, in forward
      out_dec_level1, _ = self.decoder_level1([inp_dec_level1, k_v])
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 240, in forward
      input = module(input)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/vrai/Reti-Diff/Reti-Diff/archs/S2_interface_arch.py", line 217, in forward
      x = x + self.ffn(self.norm2(x), k_v)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/vrai/Reti-Diff/Reti-Diff/archs/S2_interface_arch.py", line 89, in forward
      x1, x2 = self.dwconv(x).chunk(2, dim=1)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 554, in forward
      return self._conv_forward(input, self.weight, self.bias)
    File "/home/vrai/Reti-Diff/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
      return F.conv2d(
  RuntimeError: Expected canUse32BitIndexMath(input) && canUse32BitIndexMath(output) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

I cannot reproduce the issue with torch==2.7.0+cu126 using:

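import torch

device = "cuda"
# 1024 * 1 * 8192 * 1024 = 8,589,934,592 elements, well above the int32 limit
x = torch.randn(1024, 1, 1024*8, 1024, device=device)
weight = torch.randn(1, 1, 1, 1, device=device)

print(x.numel() > 2**31 - 1)

# Regular (groups=1) convolution over the oversized input
out = torch.nn.functional.conv2d(x, weight, stride=1)
print(out.shape)

Output:
True
torch.Size([1024, 1, 8192, 1024])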

I tried your code and I don’t hit the error. I get an OOM instead, which is expected.
Thanks for the reply. I’m not an expert and I’d like to dig deeper into this.
Since my code calls the standard conv2d function (last frame of the stack trace), where could the problem be located?

A quick search points to these methods, so maybe you are using a depthwise conv? The failing call in your stack trace is self.dwconv(x), which would fit.
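If the depthwise path is indeed the culprit, one possible workaround is to split the conv by channel. In a depthwise conv (groups == in_channels), each channel is computed independently. The conv can therefore run on contiguous channel ranges that each stay under the limit, with each result written into a preallocated output.

This sketch rests on several assumptions:
- stride, padding and dilation are integers, and zero padding is used;
- the weight is laid out as (in_channels * multiplier, 1, kH, kW);
- each single channel is already below the limit (otherwise a spatial split with overlapping halos would be needed).

The names plan_channel_chunks and chunked_depthwise_conv2d are hypothetical, not a PyTorch API:

```python
# Assumed bound: canUse32BitIndexMath appears to require numel < INT32_MAX.
INT32_MAX = 2**31 - 1


def plan_channel_chunks(num_channels, elems_per_channel, limit=INT32_MAX):
    """Split [0, num_channels) into ranges whose element count stays below limit."""
    if elems_per_channel >= limit:
        raise ValueError("one channel alone exceeds the limit; split spatially instead")
    step = (limit - 1) // max(elems_per_channel, 1)
    return [(start, min(start + step, num_channels))
            for start in range(0, num_channels, step)]


def chunked_depthwise_conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1):
    """Depthwise conv2d (groups == in_channels) applied to one channel range at a time."""
    import torch.nn.functional as F  # lazy import: the planner above needs no PyTorch

    n, c, h, w = x.shape
    mult = weight.shape[0] // c  # output channels per input channel
    kh, kw = weight.shape[-2:]
    h_out = (h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    w_out = (w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1

    # Per-channel size of the larger of the input slice and the output slice
    per_channel = n * max(h * w, mult * h_out * w_out)
    out = x.new_empty((n, c * mult, h_out, w_out))
    for c0, c1 in plan_channel_chunks(c, per_channel):
        o0, o1 = c0 * mult, c1 * mult
        out[:, o0:o1] = F.conv2d(
            x[:, c0:c1].contiguous(),  # contiguous copy keeps the index range small
            weight[o0:o1],
            None if bias is None else bias[o0:o1],
            stride=stride, padding=padding, dilation=dilation,
            groups=c1 - c0,
        )
    return out
```

In the FFN from the stack trace, this would mean replacing self.dwconv(x) with something like chunked_depthwise_conv2d(x, self.dwconv.weight, self.dwconv.bias, stride=self.dwconv.stride[0], padding=self.dwconv.padding[0], dilation=self.dwconv.dilation[0]). That assumes dwconv is a square-kernel nn.Conv2d with groups equal to its input channels. The full output is still allocated, so this should only keep each individual conv call under the limit. It does not reduce memory use, and it is untested.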

I am running into the same error (more than 2^31 elements) when using the latest PyTorch 2.7.0 with cuDNN > 9.3 and a depthwise conv:

import torch
import torch.nn as nn

device = torch.device("cuda")

# Define an extremely large input tensor (more than 2**31 elements for a single sample) and pass it through a grouped (depthwise) convolution
# For example: Batch size = 1, Channels = 2, Height = 32,800, Width = 32,800
# Total elements = 1 * 2 * 32,800 * 32,800 = 2,151,680,000 > 2**31 (2,147,483,648)
num_channels=2
input_tensor = torch.randn(1, num_channels, 32800, 32800, device=device)
conv_layer = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1, groups=num_channels).to(device)
output_tensor = conv_layer(input_tensor)

What would be needed to extend support for tensors with more than 2^31 elements to groups > 1 (depthwise separable convs)? Are there any plans to support this in the future?

I have also opened a related issue on the PyTorch GitHub: "Depthwise Separable Convolutions with Large Tensors (> 2**31) Elements) Fail Despite cuDNN 64-bit Indexing Support" (pytorch/pytorch issue #152816).