This isn’t really a bug, but rather a quirk of how PyTorch optimizes the memory layout of single-element results from nonzero().
The key difference is that:
- For a single-element result (shape (1, 3) here), PyTorch reports a compact stride of (1, 1)
- For multiple elements (shape (N, 3)), it reports the expected row-major stride of (3, 1)

The different stride pattern doesn’t make any operation wrong, but it can be surprising when you’re working with strides explicitly.
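For example (a minimal repro, assuming a 3-D input tensor so that torch.nonzero() returns one index column per dimension; the exact stride reported can vary across PyTorch versions):

```python
import torch

t = torch.zeros(2, 3, 4)
t[0, 1, 2] = 1
print(torch.nonzero(t).stride())  # compact stride, e.g. (1, 1), with a single match

t[1, 2, 3] = 1
print(torch.nonzero(t).stride())  # row-major stride (3, 1) with multiple matches
```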
If you need consistent stride behavior, you can normalize the single-element case so it matches the multi-element one:
```python
import torch

nonzero = torch.nonzero(example)
if nonzero.shape[0] == 1:
    # view() recomputes the strides, which should give you the (3, 1) stride
    nonzero = nonzero.clone().view(1, -1)
```
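One detail worth knowing (my reading of PyTorch’s contiguity rules; worth double-checking on your version): calling .contiguous() instead won’t change anything, because the contiguity check ignores strides on size-1 dimensions, so the (1, 1)-strided result already counts as contiguous. It’s the view() call that reassigns the standard row-major strides; clone() just ensures you aren’t aliasing the original storage:

```python
nz = torch.nonzero(example)              # assume example holds exactly one nonzero
print(nz.is_contiguous())                # True even with the (1, 1) stride
print(nz.contiguous().stride())          # still (1, 1): contiguous() is a no-op here
print(nz.clone().view(1, -1).stride())   # (3, 1), matching the multi-element case
```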