Short-circuiting torch.where?

Does torch.where(cond, x, y) always evaluate all of x and all of y? Is there a way to avoid computing the elements of x where cond is false?

Isn’t x already a full Tensor? So you have to compute it before passing it to torch.where(), right?

I am actually doing something like torch.where(cond, f(x), g(y)), where f and g act element-wise and can be expensive. For example, see the definitions of logerfc and logerfcx in this code snippet: How to call SciPy functions? - #3 by f3ba


f(x) and g(y) are evaluated in full before their results are passed to torch.where, because Python evaluates function arguments eagerly. To skip computing values for certain elements, the implementations of f and g would need to know about cond.
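To see that no short-circuiting happens, here is a small sketch with two hypothetical branch functions f and g (cheap stand-ins for expensive element-wise functions) that record how many elements they are called on:

```python
import torch

calls = []  # records (name, number of elements) for each branch call

def f(t):
    # hypothetical stand-in for an expensive element-wise function
    calls.append(("f", t.numel()))
    return torch.exp(t)

def g(t):
    # hypothetical stand-in for a second expensive element-wise function
    calls.append(("g", t.numel()))
    return torch.log1p(t.abs())

x = torch.linspace(-2.0, 2.0, steps=5)
cond = x > 0

# Both arguments are computed before torch.where is called,
# so f and g each process all 5 elements regardless of cond.
out = torch.where(cond, f(x), g(x))
print(calls)  # [('f', 5), ('g', 5)]
```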

Maybe you could apply f only to x.masked_select(cond), which returns a 1-D tensor containing just the elements where cond is True, and then use the indices of the True entries in cond to scatter the result of f back to the original shape.

E.g. something like:

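import torch

indices = cond.nonzero().t()  # (ndim, nnz) coordinates of the True entries; works for any number of dims
values = f(x.masked_select(cond))  # only runs f on the selected values (same row-major order as nonzero)
sparse = torch.sparse_coo_tensor(indices, values, x.size())
result = sparse.to_dense()  # f(x) where cond is True, 0 elsewhere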
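An alternative that skips the sparse round-trip is plain boolean-mask indexing: evaluate f only on x[cond] and g only on y[~cond], then write both results into a preallocated output. The helper name where_lazy is hypothetical, and this is a sketch assuming cond, x and y all have the same shape and that f and g are element-wise:

```python
import torch

def where_lazy(cond, x, y, f, g):
    # Hypothetical helper: same result as torch.where(cond, f(x), g(y)),
    # but f only sees the elements of x where cond is True and
    # g only sees the elements of y where cond is False.
    # Assumes cond, x and y share one shape and f, g are element-wise.
    fx = f(x[cond])   # 1-D tensor, one value per True entry
    gy = g(y[~cond])  # 1-D tensor, one value per False entry
    out = torch.empty(cond.shape,
                      dtype=torch.promote_types(fx.dtype, gy.dtype),
                      device=x.device)
    out[cond] = fx    # written back in row-major order of the True entries
    out[~cond] = gy   # written back in row-major order of the False entries
    return out

x = torch.tensor([-1.0, 0.5, 2.0, -3.0])
cond = x > 0
out = where_lazy(cond, x, x, torch.exp, torch.neg)
# same values as torch.where(cond, torch.exp(x), torch.neg(x))
```

Boolean-mask assignment is differentiable, so gradients should flow back through f and g to x and y as usual.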