RuntimeError: n cannot be greater than 2^24+1 for Float type

So, this code results in an error message:

import torch
from torch.utils.cpp_extension import load

# Build and load the C++/CUDA extension.
cpp = load(name="histogram_cpp", sources=["histogram.cpp", "histogram.cu"])

dtype = torch.cuda.FloatTensor

a = torch.randn(128, 512, 720).type(dtype)
b = torch.randn(128, 256).type(dtype)

cpp.matchHistogram(a, b)  # Raises the RuntimeError above

The matchHistogram() function comes from these files:

https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/63e935bf8315080e2a7b7b002459141cfc3bffe8/histogram.cpp
https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/63e935bf8315080e2a7b7b002459141cfc3bffe8/histogram.cu

This is the full error message:

Traceback (most recent call last):
  File "t.py", line 14, in <module>
    cpp.matchHistogram(a, b)
RuntimeError: n cannot be greater than 2^24+1 for Float type. (check_supported_max_int_with_precision at /pytorch/aten/src/ATen/native/TensorFactories.h:78)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f0ff5f36813 in /usr/local/lib/python3.5/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1bb1638 (0x7f0ff8142638 in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so)
frame #2: at::native::randperm_out_cpu(at::Tensor&, long, at::Generator*) + 0x3c (0x7f0ff813ad0c in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x1d9e3e4 (0x7f0ff832f3e4 in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so)
frame #4: at::native::randperm(long, at::Generator*, c10::TensorOptions const&) + 0xab (0x7f0ff81375eb in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so)
frame #5: at::native::randperm(long, c10::TensorOptions const&) + 0xe (0x7f0ff81376ee in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so)
frame #6: <unknown function> + 0x1ecce9b (0x7f0ff845de9b in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so)
frame #7: at::Tensor at::ATenOpTable::callUnboxed<at::Tensor, long, c10::TensorOptions const&>(long, c10::TensorOptions const&) const + 0xb6 (0x7f0ff330e1d4 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #8: <unknown function> + 0x82f69 (0x7f0ff3300f69 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #9: torch::randperm(long, c10::TensorOptions const&)::{lambda()#1}::operator()() const + 0x97 (0x7f0ff3309b81 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #10: torch::randperm(long, c10::TensorOptions const&) + 0x192 (0x7f0ff3309d5c in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #11: matchHistogram(at::Tensor&, at::Tensor&) + 0x10a (0x7f0ff3301696 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #12: <unknown function> + 0x7e653 (0x7f0ff32fc653 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #13: <unknown function> + 0x7b692 (0x7f0ff32f9692 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #14: <unknown function> + 0x77343 (0x7f0ff32f5343 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #15: <unknown function> + 0x77533 (0x7f0ff32f5533 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
frame #16: <unknown function> + 0x6a4a1 (0x7f0ff32e84a1 in /tmp/torch_extensions/histogram_cpp/histogram_cpp.so)
<omitting python frames>
frame #19: python3() [0x540199]
frame #21: python3() [0x60c272]
frame #26: __libc_start_main + 0xf0 (0x7f104d266830 in /lib/x86_64-linux-gnu/libc.so.6)

I am not sure whether this is a limitation of PyTorch or of C++, and whether I can work around it in my loss function by resizing the tensors at every layer that the matchHistogram() function operates on.
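Since the traceback shows the failure comes from torch::randperm, the limit is presumably tied to the input's element count. A minimal sketch of a guard, assuming the threshold is exactly the 2^24 + 1 stated in the error message (the name fits_float32_randperm is hypothetical, not from the gist):

```python
# The error message says randperm's n cannot exceed 2**24 + 1 for the
# Float type; 2**24 is the largest n for which 0..n-1 are all exactly
# representable as float32 integers.
FLOAT32_MAX_EXACT_INT = 2 ** 24  # 16_777_216

def fits_float32_randperm(shape):
    """Hypothetical check: True if a tensor of this shape has few enough
    elements for a float32 randperm over its numel()."""
    n = 1
    for dim in shape:
        n *= dim
    return n <= FLOAT32_MAX_EXACT_INT + 1

print(fits_float32_randperm((128, 512, 720)))  # 47_185_920 elements -> False
print(fits_float32_randperm((128, 512, 256)))  # 16_777_216 elements -> True
```

One could check this before calling matchHistogram() and downsample the offending activations first, if that turns out to be the actual constraint.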

You seem to call randperm with a number > 2^24, so I think there is a problem with your code. Can you check which value you pass to randperm, to make sure it's what you expect?

In the example above, a.numel() = 47185920. If I instead use a = torch.randn(128, 512, 256), then a.numel() equals 16777216 (exactly 2^24) and that works.
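That boundary matches IEEE 754 single precision: 2^24 is the last point where consecutive integers are exactly representable in float32, so beyond it a float tensor cannot hold every index of a permutation. A quick stdlib sketch (no torch required) illustrating the rounding:

```python
import struct

def to_float32(x):
    # Round-trip a Python float through IEEE 754 single precision.
    return struct.unpack('f', struct.pack('f', x))[0]

print(to_float32(16777216.0))  # 2**24     -> 16777216.0 (exact)
print(to_float32(16777217.0))  # 2**24 + 1 -> 16777216.0 (rounds down)
```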

Given that you generate these indices in a float32 tensor, values that large cannot be represented exactly: float32 only represents consecutive integers exactly up to 2^24, which is why randperm rejects n > 2^24 + 1 for the Float type.