Hi,
I would like to use `torch.where` to insert a default value into my tensor wherever an index is out of range (which would otherwise raise an exception). Here is a minimal example:
```python
>>> a = torch.tensor([3,1,5])
>>> b = torch.tensor([2,1,0,2,3,1])
>>> torch.where(b < len(a), a[b], -1)  # expected: tensor([5, 1, 3, 5, -1, 1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index 3 is out of bounds for dimension 0 with size 3
```
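The error occurs because Python evaluates the arguments of `torch.where` before the call, so `a[b]` is computed for every entry of `b`, including the out-of-range index 3, before the condition is ever applied. Is there an alternative to this code that evaluates `a[b]` only at the positions `i` where `b[i]` satisfies the condition?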
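For reference, two possible workarounds are sketched below. This is a minimal sketch, assuming `a` is a non-empty 1-D tensor, `b` is a 1-D integer (int64) index tensor, and the only invalid indices are those `>= len(a)` (the same condition as in the question). The names `gather_with_default` and `gather_with_default_masked` are hypothetical helpers, not PyTorch APIs. The first keeps `torch.where` but clamps the indices so that `a[...]` can never fail; the second indexes `a` only at the positions where the condition holds.

```python
import torch


def gather_with_default(a: torch.Tensor, b: torch.Tensor, default: int = -1) -> torch.Tensor:
    """Approach 1: clamp indices, then select with torch.where."""
    # Same validity condition as in the question.
    valid = b < len(a)
    # Map out-of-range indices onto the last valid index so that a[safe_b]
    # can be evaluated everywhere without raising IndexError.
    safe_b = b.clamp(max=len(a) - 1)
    gathered = a[safe_b]
    # The placeholder values gathered at invalid positions are discarded here.
    return torch.where(valid, gathered, torch.full_like(gathered, default))


def gather_with_default_masked(a: torch.Tensor, b: torch.Tensor, default: int = -1) -> torch.Tensor:
    """Approach 2: index `a` only where the condition holds."""
    valid = b < len(a)
    # Start from a tensor filled with the default value (same dtype/device as a).
    out = torch.full(b.shape, default, dtype=a.dtype, device=a.device)
    # `a` is indexed only with the valid entries of `b`.
    out[valid] = a[b[valid]]
    return out


a = torch.tensor([3, 1, 5])
b = torch.tensor([2, 1, 0, 2, 3, 1])
print(gather_with_default(a, b).tolist())         # [5, 1, 3, 5, -1, 1]
print(gather_with_default_masked(a, b).tolist())  # [5, 1, 3, 5, -1, 1]
```

The masked version matches the question most literally, since out-of-range indices are never used to read from `a`; the clamped version keeps a single `torch.where` call at the cost of gathering a few throwaway values.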