Hello everyone!
After hours of debugging, I noticed that `.stride()` behaves (in my opinion) strangely when applied to the output of a `.nonzero()` operation. Here is a minimal working example of what I am encountering:
```python
import torch


def test_stride():
    print()
    example = torch.tensor([[[False, True],
                             [False, False]]])
    nonzero = torch.nonzero(example).contiguous()
    print(f"Nonzero output: {nonzero}")  # Nonzero output: tensor([[0, 0, 1]])
    print(f"Nonzero stride: {nonzero.stride()}")  # Nonzero stride: (1, 1)

    equivalent = torch.tensor([[0, 0, 1]]).contiguous()
    print(f"Equivalent output: {equivalent}")  # Equivalent output: tensor([[0, 0, 1]])
    print(f"Equivalent stride: {equivalent.stride()}")  # Equivalent stride: (3, 1)


test_stride()
```
I would assume that in both cases the stride of the tensor is `(3, 1)`, but no matter what I do to the output tensor of the `.nonzero()` operation, the stride stays `(1, 1)`. This only happens as long as the output contains a single row (i.e. the input has exactly one nonzero element); as soon as two or more rows are returned, the stride is as expected.
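To narrow this down, here is a minimal sketch that reproduces the odd stride without `.nonzero()` at all. It assumes the single-row output is simply a 1x3 tensor with stride `(1, 1)`, built by hand with `torch.as_strided`; the 2x3 case uses a hypothetical column-major layout and is not necessarily what `.nonzero()` actually returns. The idea being tested is that PyTorch ignores the stride of a size-1 dimension when deciding whether a tensor is contiguous, so `.contiguous()` has nothing to fix.

```python
import torch


def stride_probe():
    # 1x3 view whose dim-0 stride is 1 instead of the "canonical" 3,
    # mimicking the stride reported for the single-row .nonzero() output.
    odd = torch.tensor([0, 0, 1]).as_strided((1, 3), (1, 1))
    canonical = torch.tensor([[0, 0, 1]])

    # Same values as the canonically laid-out tensor.
    same_values = torch.equal(odd, canonical)
    # Reported as contiguous: the stride of a size-1 dim never addresses an element.
    odd_contig = odd.is_contiguous()
    # So .contiguous() hands back the very same tensor, strides included.
    unchanged = odd.contiguous().data_ptr() == odd.data_ptr()
    # A copy into freshly allocated memory gets the default strides.
    fresh = torch.empty(odd.shape, dtype=odd.dtype).copy_(odd)

    # Hypothetical 2x3 column-major view: now dim 0 has size 2, its stride
    # matters, the tensor is NOT contiguous and .contiguous() really copies.
    two = torch.arange(6).as_strided((2, 3), (1, 2))
    two_contig = two.is_contiguous()
    two_fixed = two.contiguous().stride()

    return (odd.stride(), same_values, odd_contig, unchanged,
            fresh.stride(), two_contig, two_fixed)


print(stride_probe())
```

If this holds, the `(1, 1)` stride would be harmless for indexing, and an explicit copy such as `torch.empty(t.shape, dtype=t.dtype).copy_(t)` would be one way to force `(3, 1)` when downstream code inspects strides directly.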
Is this a bug, or am I overlooking something? If the latter, please tell me how to resolve my issue. Thanks!