Hello,
I would like to use the torch.where function on a sparse tensor.
Is there a torch.sparse function that implements this (or is one planned)?
If not, is there a workaround?
Example:
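```python
import torch

# torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
i = torch.tensor([[0, 2], [1, 0], [1, 2]], dtype=torch.long)
v = torch.tensor([3.0, 4.0, 0.0])
x = torch.sparse_coo_tensor(i.t(), v, (3, 3))
# Working: convert to dense first
torch.where(x.to_dense() > 0, 1, 0)
# Not working: comparing the sparse tensor directly
torch.where(x > 0, 1, 0)
```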
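One possible workaround, sketched under the assumption that the "else" value is 0: every unstored entry of a sparse tensor is zero, and since 0 > 0 is False it would map to 0 anyway, so torch.where only has to run on the stored values. The helper name `sparse_where_positive` is hypothetical, not part of torch.sparse. This trick does not carry over to a nonzero "else" value, because then every implicit zero would become nonzero and the result would effectively be dense.

```python
import torch

def sparse_where_positive(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sparse analogue of torch.where(x > 0, 1, 0).

    Assumes the "else" value is 0, so implicit zeros can stay implicit.
    """
    x = x.coalesce()  # merge duplicate indices; indices()/values() need this
    vals = x.values()
    # Apply torch.where to the stored values only
    new_vals = torch.where(
        vals > 0,
        torch.ones_like(vals, dtype=torch.long),
        torch.zeros_like(vals, dtype=torch.long),
    )
    # Optionally drop entries that became (or already were) explicit zeros
    keep = new_vals != 0
    return torch.sparse_coo_tensor(x.indices()[:, keep], new_vals[keep], x.shape)

i = torch.tensor([[0, 2], [1, 0], [1, 2]], dtype=torch.long)
v = torch.tensor([3.0, 4.0, 0.0])
x = torch.sparse_coo_tensor(i.t(), v, (3, 3))
y = sparse_where_positive(x)
print(y.to_dense())
```

The dense result should match `torch.where(x.to_dense() > 0, 1, 0)`, while `y` itself stays sparse and only stores the two positive positions.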