tf.where equivalent in PyTorch

Here is an example of tf.where usage:

import tensorflow as tf
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Constants (3-element arrays).
a = tf.constant([100, 200, 300])
b = tf.constant([1, 2, 3])

# Use a placeholder for the predicate passed to where (TensorFlow 1.x graph API).
# ... We pass in an array of 3 bools to fill the placeholder.
j = tf.placeholder(tf.bool, [3])

# Use where to apply 1 of 2 methods based on each predicate.
# ... First argument is the predicate (contains bools).
#     Second argument supplies the values where the predicate is true.
#     Third argument supplies the values where it is false.
x = tf.where(j, a + 5000, a + b)

# Run with 3 bools in placeholder.
array_temp = [False, True, False]
result = tf.Session().run(x, {j: array_temp})

# For false, add the two elements together.
# ... For true, add 5000 to the element of a.
print(result)

I'm not sure whether PyTorch has something that does the same job as tf.where in this case.

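I think torch.where should do the same as tf.where: torch.where(condition, x, y) returns the elements of x where condition is True and the elements of y elsewhere.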
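A minimal sketch of the example above translated to PyTorch, assuming a recent PyTorch release. No placeholder or session is needed, because PyTorch runs eagerly and the predicate is just a bool tensor. The one-argument form is shown too, since it behaves differently: torch.where(condition) returns a tuple of index tensors (one per dimension), whereas tf.where(condition) returns a 2-D tensor of coordinates.

```python
import torch

# Constants (3-element tensors); torch.tensor takes the place of tf.constant.
a = torch.tensor([100, 200, 300])
b = torch.tensor([1, 2, 3])

# The predicate is an ordinary bool tensor instead of a placeholder.
j = torch.tensor([False, True, False])

# torch.where(condition, x, y): take x where condition is True, else y.
result = torch.where(j, a + 5000, a + b)
print(result.tolist())  # → [101, 5200, 303]

# With only a condition, torch.where returns a tuple of index tensors,
# equivalent to torch.nonzero(j, as_tuple=True).
idx = torch.where(j)
print(idx[0].tolist())  # → [1]
```

Note that, as in TensorFlow, both a + 5000 and a + b are computed in full before torch.where selects from them.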


Yes, I just found this as the solution. Thanks!

@ptrblck @bdqnghi

Can I ask what the PyTorch equivalent of tf.constant is?

You could create tensors directly with torch.tensor.
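A short sketch of torch.tensor standing in for tf.constant, assuming a recent PyTorch release. The dtype is inferred from the data unless passed explicitly, much like tf.constant's dtype argument. One difference worth keeping in mind: a PyTorch tensor is mutable, so in-place assignment works, whereas a TensorFlow constant cannot be modified.

```python
import torch

# Roughly tf.constant([100, 200, 300]); dtype is inferred as torch.int64.
a = torch.tensor([100, 200, 300])

# An explicit dtype can be given, similar to tf.constant(..., dtype=...).
b = torch.tensor([1, 2, 3], dtype=torch.float32)

# Unlike a TensorFlow constant, a PyTorch tensor can be changed in place.
a[0] = 101

print(a.dtype, a.tolist())  # → torch.int64 [101, 200, 300]
print(b.dtype, b.tolist())  # → torch.float32 [1.0, 2.0, 3.0]
```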