I am converting code from TensorFlow 1 to PyTorch, and I want to convert a `tf.where` operation. PyTorch's `torch.where` function is equivalent to `tf.where` in TensorFlow 2, but not to `tf.where` in TensorFlow 1, which does not support broadcasting the way `np.where` does, as the example below shows:
```python
import torch

condition = torch.tensor([True, True, True, True])
x = torch.tensor([[0.],
                  [0.],
                  [0.],
                  [0.]])
y = torch.tensor([[1.],
                  [0.],
                  [0.],
                  [0.]])
torch.where(condition, x, y)
```
gives the result:
```
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
```
whereas in TensorFlow 1:
```python
import tensorflow.compat.v1 as tf

condition = tf.convert_to_tensor([True, True, True, True])
x = tf.convert_to_tensor([[0.],
                          [0.],
                          [0.],
                          [0.]])
y = tf.convert_to_tensor([[1.],
                          [0.],
                          [0.],
                          [0.]])
tf.where(condition, x, y)
```
the result is:
```
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[0.],
       [0.],
       [0.],
       [0.]], dtype=float32)>
```
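The (4, 4) shape in PyTorch comes from its NumPy-style broadcasting: shapes are aligned from the right, so the rank-1 `condition` of shape (4,) behaves like (1, 4), and combining it with the (4, 1) operands expands everything to (4, 4). A minimal check of this, assuming PyTorch 1.8 or later (where `torch.broadcast_shapes` is available) and reusing the tensors from the example:

```python
import torch

condition = torch.tensor([True, True, True, True])  # shape (4,)
x = torch.zeros(4, 1)                                # shape (4, 1)
y = torch.tensor([[1.], [0.], [0.], [0.]])           # shape (4, 1)

# (4,) is right-aligned against (4, 1), i.e. treated as (1, 4),
# so the common broadcast shape of all three operands is (4, 4).
print(torch.broadcast_shapes(condition.shape, x.shape, y.shape))  # torch.Size([4, 4])
print(torch.where(condition, x, y).shape)                         # torch.Size([4, 4])
```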
So my question is: how can I get the same result as in TensorFlow 1 with PyTorch?
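One possible approach, sketched under the assumption that the TensorFlow 1 semantics are those described for the three-argument `tf.compat.v1.where`: `x` and `y` must have the same shape, and `condition` either has that same shape or is rank 1 with length equal to `x.shape[0]`, in which case it selects whole rows. Reshaping a rank-1 `condition` to `(N, 1, ..., 1)` makes `torch.where` broadcast along the trailing dimensions only, which reproduces that row selection. `tf1_where` is a hypothetical helper name, not a PyTorch API, and the one-argument form of `tf.where` (which returns indices) is not covered.

```python
import torch


def tf1_where(condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mimicking the 3-argument tf.compat.v1.where(condition, x, y)."""
    # TF1 where does not broadcast x against y, so require identical shapes.
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")
    if condition.dim() == 1 and x.dim() > 1:
        # A rank-1 condition picks entire rows, so its length must match dim 0 of x.
        if condition.shape[0] != x.shape[0]:
            raise ValueError("rank-1 condition must match the first dimension of x")
        # Append singleton dims: (N,) -> (N, 1, ..., 1), so broadcasting only
        # expands along the trailing dimensions instead of creating new ones.
        condition = condition.reshape(-1, *([1] * (x.dim() - 1)))
    elif condition.shape != x.shape:
        raise ValueError("condition must have the same shape as x, or be rank 1")
    return torch.where(condition, x, y)


condition = torch.tensor([True, True, True, True])
x = torch.zeros(4, 1)
y = torch.tensor([[1.], [0.], [0.], [0.]])
print(tf1_where(condition, x, y).shape)  # torch.Size([4, 1])
```

For the 2-D case in the question alone, `torch.where(condition.unsqueeze(1), x, y)` does the same reshaping inline; the helper just generalizes it to inputs of any rank and checks the shape constraints explicitly.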