It looks like nonzero on a 2D tensor returns the coordinates of the non-zero elements. Is there a function that can take these coordinates and a list of values and recreate the original tensor?
You could use the following code:
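import torch
x = torch.empty(5, 5).random_(3)  # random 5x5 tensor with values in {0, 1, 2}
idx = x.nonzero()  # (N, 2) tensor of row/column coordinates of the non-zero entries
y = torch.zeros(5, 5)
y[idx[:, 0], idx[:, 1]] = x[idx[:, 0], idx[:, 1]]  # write the values back at those coordinates
print((x == y).all())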
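If you'd rather have a single call that takes coordinates plus a list of values, one option is to build a sparse COO tensor and convert it to dense. This is a minimal sketch, assuming the coordinates come from nonzero() (so there are no duplicate entries, which to_dense() would otherwise sum) and that you know the original shape:

```python
import torch

# Random 5x5 tensor with values in {0, 1, 2}, as in the answer above
x = torch.empty(5, 5).random_(3)

# (N, 2) coordinates of the non-zero entries and the values stored there
idx = x.nonzero()
values = x[idx[:, 0], idx[:, 1]]

# sparse_coo_tensor expects indices of shape (ndim, N), hence the transpose;
# to_dense() fills every unlisted position with zero
y = torch.sparse_coo_tensor(idx.t(), values, size=x.shape).to_dense()

print(torch.equal(x, y))
```

The sparse route mainly helps when the coordinates and values arrive separately (e.g. loaded from disk); if you still have x around, the indexing version above is simpler.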
I’m not sure whether this has already been addressed, since x[x.nonzero()]
doesn’t seem to be supported at the moment.
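For reference, in recent PyTorch releases x[x.nonzero()] doesn't raise an error, but it treats the whole (N, 2) index tensor as an index into the first dimension, so it gathers rows rather than elements. A sketch of the element-wise alternative, assuming a version where nonzero accepts as_tuple=True (which returns one index tensor per dimension):

```python
import torch

x = torch.empty(5, 5).random_(3)

# The (N, 2) tensor indexes dim 0 only, so this gathers whole rows: shape (N, 2, 5)
rows = x[x.nonzero()]

# One index tensor per dimension selects individual elements instead;
# x[tuple(x.nonzero().t())] is an equivalent spelling
coords = x.nonzero(as_tuple=True)
values = x[coords]

# Scatter the values back into a zero tensor of the same shape
y = torch.zeros_like(x)
y[coords] = values
print(torch.equal(x, y))
```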
This works great. Thanks!