PyTorch equivalent of tf.where in TensorFlow 1

I am converting code from TensorFlow 1 to PyTorch and want to convert a tf.where operation. PyTorch's torch.where function is equivalent to tf.where in TensorFlow 2, but not to tf.where in TensorFlow 1. The TensorFlow 1 version does not broadcast the way np.where does, as the example below shows:

import torch

condition = torch.tensor([True, True, True, True])

x = torch.tensor([[0.],
                  [0.],
                  [0.],
                  [0.]])

y = torch.tensor([[1.],
                  [0.],
                  [0.],
                  [0.]])

torch.where(condition, x, y)

gives the result:

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
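The 4x4 shape comes from standard broadcasting. The condition's shape (4,) is aligned with the trailing dimension of x and y, whose shape is (4, 1), so all three are expanded to (4, 4). A quick check with torch.broadcast_shapes (available in recent PyTorch releases, roughly 1.8 onward) is a small sketch that confirms this:

```python
import torch

# Shapes from the example: condition is (4,), x and y are (4, 1).
# Broadcasting aligns shapes from the right, so (4,) behaves like (1, 4)
# and is stretched against the (4, 1) tensors.
result_shape = torch.broadcast_shapes((4,), (4, 1), (4, 1))
print(result_shape)  # torch.Size([4, 4])
```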

whereas in TensorFlow 1:

import tensorflow.compat.v1 as tf

condition = tf.convert_to_tensor([True, True, True, True])

x = tf.convert_to_tensor([[0.],
                          [0.],
                          [0.],
                          [0.]])

y = tf.convert_to_tensor([[1.],
                          [0.],
                          [0.],
                          [0.]])

tf.where(condition, x, y)

the result is:

<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[0.],
       [0.],
       [0.],
       [0.]], dtype=float32)>
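TensorFlow 1's tf.where does not broadcast here. When condition is rank 1 and x has higher rank, the first dimension of x must match the length of condition, and each condition[i] selects the whole slice x[i] or y[i]. A slow but explicit PyTorch reference for that rule could be sketched as follows; the helper name tf1_where_reference is hypothetical and only covers the rank-1 case:

```python
import torch

def tf1_where_reference(condition, x, y):
    """Row-wise selection mimicking TensorFlow 1 tf.where for a rank-1
    condition (hypothetical helper, for illustration only)."""
    assert condition.dim() == 1 and condition.shape[0] == x.shape[0]
    # condition[i] picks the entire slice x[i] or y[i]; nothing is broadcast.
    rows = [x[i] if condition[i] else y[i] for i in range(condition.shape[0])]
    return torch.stack(rows)

condition = torch.tensor([True, False, True, False])
x = torch.zeros(4, 1)
y = torch.ones(4, 1)
out = tf1_where_reference(condition, x, y)
print(out.shape)  # torch.Size([4, 1]); rows 0 and 2 come from x, rows 1 and 3 from y
```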

So my question is: how can I get the same result as in TensorFlow 1 with PyTorch?

It turns out the following works, because reshaping the condition to the shape of x means no broadcasting takes place:

torch.where(condition.view(x.size()), x, y)
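The view(x.size()) trick works here because condition and x both have four elements. If x had more than one column, for example shape (4, 3), the view would fail. A more general sketch, assuming a rank-1 condition whose length matches the first dimension of x and y, reshapes the condition to (N, 1, ..., 1) so it broadcasts only along the trailing dimensions. The name tf1_where is hypothetical:

```python
import torch

def tf1_where(condition, x, y):
    # Emulates TensorFlow 1 tf.where when condition either has the same
    # shape as x or is rank 1 with length x.shape[0] (hypothetical helper).
    if condition.dim() == 1 and x.dim() > 1:
        # (N,) -> (N, 1, ..., 1): one trailing singleton per extra dim of x,
        # so each condition[i] selects the entire slice x[i] or y[i].
        condition = condition.reshape(-1, *([1] * (x.dim() - 1)))
    return torch.where(condition, x, y)

condition = torch.tensor([True, False, True, False])
x = torch.zeros(4, 3)
y = torch.ones(4, 3)
out = tf1_where(condition, x, y)
print(out.shape)  # torch.Size([4, 3])
```

For the 2-D case in the question, condition[:, None] gives the same (4, 1) shape as the reshape above.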