# How can I do the operation the same as `np.where`?

`np.where` is powerful for conditional element-wise assignment, and it has been implemented in TensorFlow. How can I do this in PyTorch?


I don’t think there’s a PyTorch equivalent, but you could look at using `apply_()` and then calling `nonzero()` to gather the indices.

Also, `apply_()` is an in-place operation and only works on CPU tensors.
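A minimal sketch of that suggestion, assuming a CPU float tensor and using "value is positive" as a purely illustrative condition. Because `apply_()` overwrites the tensor in place, it works on a clone here. It also calls a Python function once per element, so expect it to be slow on large tensors.

```python
import torch

vec = torch.tensor([-3.2, 1000.0, 10.0, 639.0])  # CPU tensor; apply_() does not work on CUDA

# apply_() runs a Python callable on every element in place, so work on a copy
mask = vec.clone()
mask.apply_(lambda v: 1.0 if v > 0 else 0.0)  # illustrative condition: keep positive values

# for a 1-D input, nonzero() returns indices with shape (N, 1); squeeze to shape (N,)
indices = torch.nonzero(mask).squeeze(1)
print(indices.tolist())  # [1, 2, 3]
```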


Here’s the Chainer (CuPy) implementation, if anyone’s curious. It could even be wrapped directly in a PyTorch function with a couple of casting tricks (PyTorch tensors and CuPy arrays have the same memory layout).

https://github.com/pfnet/chainer/blob/master/cupy/sorting/search.py#L92


I actually solved it with a workaround:

```python
# differentiable equivalent of np.where
# cond could be a FloatTensor with zeros and ones
def where(cond, x_1, x_2):
    return (cond * x_1) + ((1 - cond) * x_2)
```
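A quick check of that workaround, assuming `cond` is a float tensor of 0s and 1s. Summing the output and calling `backward()` shows that the gradient reaches `x_1` only where `cond` is 1 and `x_2` only where it is 0, which is what makes the blend differentiable.

```python
import torch

def where(cond, x_1, x_2):
    # cond holds 0.0/1.0: pick x_1 where cond == 1 and x_2 where cond == 0
    return (cond * x_1) + ((1 - cond) * x_2)

x_1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x_2 = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)
cond = torch.tensor([1.0, 0.0, 1.0])

out = where(cond, x_1, x_2)
print(out.tolist())  # [1.0, 20.0, 3.0]

out.sum().backward()
print(x_1.grad.tolist())  # [1.0, 0.0, 1.0]: gradient only where cond == 1
print(x_2.grad.tolist())  # [0.0, 1.0, 0.0]: gradient only where cond == 0
```

One caveat: both branches are always multiplied, so a non-finite value in the branch that is *not* selected still leaks through (`0 * inf` is `nan`). The forward result of a true select such as `torch.where` is unaffected by the branch it does not pick.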

@jekbradbury I still don’t quite understand how to implement `np.where` efficiently. Could you kindly help?

Is there any update on a built-in implementation of `numpy.where`?


@truenicoco, `torch.where()` looks like it’s in 0.4! http://pytorch.org/docs/master/torch.html#torch.where
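A minimal sketch of the built-in, assuming a recent PyTorch release where comparisons return `bool` tensors (0.4 itself used byte masks). It is differentiable with respect to the selected inputs, and the one-argument form returns a tuple of index tensors, much like `np.where(cond)`.

```python
import torch

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = torch.zeros(3)

# three-argument form: element-wise select, like np.where(cond, x, y)
out = torch.where(x > 0, x, y)  # keep positive entries, zero elsewhere
print(out.tolist())  # [0.0, 0.5, 3.0]

out.sum().backward()
print(x.grad.tolist())  # [0.0, 1.0, 1.0]: gradient only where x was selected

# one-argument form: a tuple of index tensors, like np.where(cond)
idx = torch.where(x > 0)
print(idx[0].tolist())  # [1, 2]
```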

Thanks to the others for examples of alternatives for 0.3 and earlier.


Here’s an example of using @jaromiru’s workaround:

```python
import torch

# this example prunes any elements of the vector that lie above or below the bounds
vec = torch.FloatTensor([-3.2, 1000.0, 10.0, 639.0])
lower_bound = 0.0
upper_bound = 640.0
lower_bound_vec = torch.ones_like(vec) * lower_bound
upper_bound_vec = torch.ones_like(vec) * upper_bound
zeros_vec       = torch.zeros_like(vec)

def where(cond, x_1, x_2):
    cond = cond.float()  # bool/byte mask -> 0.0/1.0, stays on the same device
    return (cond * x_1) + ((1 - cond) * x_2)

# zero out everything outside [lower_bound, upper_bound]
vec = where(vec < lower_bound_vec, zeros_vec, vec)
vec = where(vec > upper_bound_vec, zeros_vec, vec)

# drop the zeroed entries (note: genuine 0.0 values are dropped too)
in_bound_indices = torch.nonzero(vec).squeeze(1)
vec = torch.index_select(vec, 0, in_bound_indices)
```

If you cast `cond` to a float tensor inside `where()`, you can pass the result of a comparison straight in and use it just like `np.where()`.
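A small check of that claim, assuming a PyTorch version that also has the built-in `torch.where` (0.4 or later) to compare against. The boolean result of `vec < 0.0` goes straight into the workaround, and the output matches the built-in.

```python
import torch

def where(cond, x_1, x_2):
    # .float() turns a bool (or byte) mask into 0.0/1.0 and keeps it on cond's device
    cond = cond.float()
    return (cond * x_1) + ((1 - cond) * x_2)

vec = torch.tensor([-3.2, 1000.0, 10.0, 639.0])
zeros_vec = torch.zeros_like(vec)

# the comparison result is passed in directly, as with np.where(vec < 0, 0, vec)
ours = where(vec < 0.0, zeros_vec, vec)
builtin = torch.where(vec < 0.0, zeros_vec, vec)
print(torch.equal(ours, builtin))  # True
```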

Note that `.clamp(min=0)` could be used together with `nonzero()` and `index_select()` to prune values below the lower bound of 0 directly.
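A minimal sketch of the `.clamp()` variant, reusing the bounds from the example above. `clamp(min=0.0)` maps everything below the lower bound to exactly 0, but the upper bound still needs a separate step; `torch.where` is used for that here, which is an assumption about how you would combine them.

```python
import torch

vec = torch.tensor([-3.2, 1000.0, 10.0, 639.0])
upper_bound = 640.0

# everything below the lower bound of 0 becomes exactly 0.0
vec = vec.clamp(min=0.0)
# values above the upper bound still have to be zeroed separately
vec = torch.where(vec > upper_bound, torch.zeros_like(vec), vec)

# drop the zeros: nonzero() gives (N, 1) indices for a 1-D input; squeeze to (N,)
in_bound_indices = torch.nonzero(vec).squeeze(1)
pruned = torch.index_select(vec, 0, in_bound_indices)
print(pruned.tolist())  # [10.0, 639.0]
```

Like the original snippet, this treats a genuine `0.0` as out of bounds. If that matters, boolean mask indexing such as `vec[(vec >= 0.0) & (vec <= upper_bound)]` keeps in-range zeros and skips the `nonzero()` step.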


Would `cond = cond.type(torch.FloatTensor)` move a GPU tensor back to the CPU?

Would it not be better to do `cond = cond.float()`?

Yes, good call, thanks! Your suggestion is better than what I did later, which was to use `cond = cond.type(dtype_float)`, where I switch `dtype_float` between `torch.cuda.FloatTensor` and `torch.FloatTensor`. `.float()` will happily keep the tensor on whichever device it is already on (CPU/GPU). I’ll edit that code snippet so anybody who copies and pastes it already has this handled.
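A quick way to see the difference, assuming a recent PyTorch; the script picks CUDA if it is available and otherwise falls back to the CPU. `.float()` changes only the dtype, while `.type(torch.FloatTensor)` names a CPU tensor type, so a CUDA mask would be copied back to the CPU.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([-1.0, 2.0, 3.0], device=device)
cond = x > 0  # bool mask on the same device as x

cond_f = cond.float()  # dtype becomes float32; device is preserved
print(cond_f.dtype, cond_f.device)

# the legacy tensor-type approach always produces a CPU tensor
cond_cpu = cond.type(torch.FloatTensor)
print(cond_cpu.device)  # cpu
```

`cond.to(torch.float32)` is an equivalent, device-preserving spelling if you prefer naming the dtype explicitly.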