Hi, I want to learn more about PyTorch, understand and discuss its design.
The issue
The shape of a single-element tensor can be confusing.
import torch
a = torch.Tensor([1])  # a.shape is torch.Size([1])
b = a[0]               # b.shape is torch.Size([])
# but both report exactly one element
a.nelement() == b.nelement()  # True
Why does it matter
a and b are both tensors with a single element, but they are fundamentally different: a has one dimension, while b has none.
This makes things like porting a greedy algorithm to PyTorch quite tricky, as nelement()
returns 1 for both a and b, but indexing b results in an error.
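To make the problem concrete, here is a minimal sketch of what I mean. It assumes a recent PyTorch release, where indexing a 0-dim tensor raises IndexError. The helper name first_value is hypothetical, something I made up for illustration and not part of PyTorch.

```python
import torch

a = torch.tensor([1.0])  # 1-D tensor, shape torch.Size([1])
b = a[0]                 # 0-D tensor, shape torch.Size([])

# Both report a single element...
assert a.nelement() == b.nelement() == 1
# ...but they differ in their number of dimensions.
assert a.dim() == 1 and b.dim() == 0

# Indexing the 0-D tensor fails even though nelement() is 1.
try:
    b[0]
    indexing_failed = False
except IndexError:
    indexing_failed = True


def first_value(t):
    """Hypothetical helper: first element as a Python number, for 0-D or N-D input."""
    # reshape(-1) flattens any shape, including torch.Size([]), to 1-D,
    # so the [0] index is always valid for a non-empty tensor.
    return t.reshape(-1)[0].item()
```

With something like this, first_value(a) and first_value(b) both go through the same code path, so checking dim() (or flattening first) seems like a safer guard than nelement() when deciding whether a tensor can be indexed.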
What might be good
A simple solution might be to let indexing a single-element tensor return an equivalent single-element tensor:
a = torch.Tensor([1])
a == a[0]
I am sure you have already thought of this; I am just curious why it is not that way.
Thanks!