Does torch.where(cond, x, y) always evaluate all of x and all of y? Is there a way to avoid computing elements of x where cond is false?
Isn’t x already a full Tensor? So you have to compute it before giving it to torch.where(), right?
I am actually doing something like torch.where(cond, f(x), g(y)), where f and g act element-wise and can be expensive. For example, see the definitions of logerfc and logerfcx in this code snippet: How to call SciPy functions? - #3 by f3ba
I am actually doing something like torch.where(cond, f(x), g(y)), where f and g act element-wise and can be expensive
f and g will be evaluated in full before their results are passed to where. To avoid computing values for certain elements, the implementations of f and g would need to know about cond.
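To make that concrete, here is a minimal sketch that records how many elements each function receives. The bodies of f and g below (torch.exp and a log1p of the absolute value) are placeholders I made up for illustration, not the functions from the question, and the calls list is just a hypothetical bookkeeping device.

```python
import torch

calls = []  # hypothetical bookkeeping: (function name, number of elements seen)

def f(t):
    # Placeholder for an expensive element-wise function.
    calls.append(("f", t.numel()))
    return torch.exp(t)

def g(t):
    # Placeholder for a second expensive element-wise function.
    calls.append(("g", t.numel()))
    return torch.log1p(t.abs())

x = torch.linspace(-2.0, 2.0, 8)
cond = x > 0

# Python evaluates f(x) and g(x) completely before torch.where is called,
# so both functions process every element regardless of cond.
out = torch.where(cond, f(x), g(x))
print(calls)
```

Both entries in calls report the full element count, because torch.where only ever receives the finished tensors.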
Maybe you could do something like f(x.masked_select(cond)), using cond as a mask, and use the indices in cond to convert the result of f back to the original dimensions.
E.g. something like:
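import torch
# cond must be a bool tensor with the same shape as x
indices = torch.nonzero(cond).t()  # shape (ndim, nnz), the layout sparse_coo_tensor expects
values = f(x.masked_select(cond))  # only runs f on the selected values
sparse = torch.sparse_coo_tensor(indices, values, x.size())
result = sparse.to_dense()  # unselected positions are filled with zeros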
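An alternative that covers both branches of torch.where(cond, f(x), g(y)) is boolean-mask indexing, which avoids the sparse round trip. The sketch below is only one possible approach: masked_where is a hypothetical helper name, and it assumes cond, x and y share the same shape and that f and g return the same dtype as x.

```python
import torch

def masked_where(cond, x, y, f, g):
    """Hypothetical helper: matches torch.where(cond, f(x), g(y)), but f only
    receives elements where cond is True and g only those where it is False."""
    out = torch.empty_like(x)
    out[cond] = f(x[cond])    # f runs on the selected elements only
    out[~cond] = g(y[~cond])  # g runs on the remaining elements only
    return out

x = torch.tensor([[-1.0, 4.0], [9.0, -16.0]])
cond = x >= 0

# sqrt is never applied to the negative entries, so no NaNs are produced.
result = masked_where(cond, x, x, torch.sqrt, lambda t: torch.zeros_like(t))
print(result)
```

Whether this is actually faster depends on how expensive f and g are relative to the gather and scatter that the indexing introduces, so it is worth timing on real inputs.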