If statement with tensor scalar in PyTorch

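torch.where is the element-wise equivalent of an if statement for tensors: torch.where(condition, input, other) takes the value from input where condition is true and from other everywhere else.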

Usage:

import torch

a = torch.rand((128, 3, 64, 64))
b = torch.rand((128, 3, 64, 64))

# element-wise: take a where a < b, otherwise take b
c = torch.where(a < b, a, b)

https://pytorch.org/docs/stable/generated/torch.where.html
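A small sketch with hand-picked values (chosen here for illustration, not part of the original answer) makes the element-wise selection visible. It also shows that a 0-d tensor broadcasts, so it can stand in for a scalar branch:

```python
import torch

a = torch.tensor([1.0, 5.0, 3.0])
b = torch.tensor([4.0, 2.0, 3.0])

# For each position: a if a < b else b (i.e. an element-wise minimum)
c = torch.where(a < b, a, b)
print(c)  # tensor([1., 2., 3.])

# A 0-d tensor broadcasts, so it can serve as a scalar "else" branch
zero = torch.tensor(0.0)
d = torch.where(a > 2.0, a, zero)
print(d)  # tensor([0., 5., 3.])
```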

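torch.where itself supports autograd: the gradient flows to a where the condition is true and to b elsewhere, while the boolean condition receives no gradient. If you need different gradient behaviour, you can write a custom backward function, as described here.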
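A minimal sketch of the custom-backward route, assuming the goal is the same a < b selection: a torch.autograd.Function saves its inputs and routes the incoming gradient by hand. The class name WhereLess is hypothetical. The backward written here reproduces what torch.where already does, so the pattern is only worth it when you need to change that backward.

```python
import torch

class WhereLess(torch.autograd.Function):
    """Hypothetical example: returns a where a < b, else b."""

    @staticmethod
    def forward(ctx, a, b):
        # Save the inputs; the mask is recomputed in backward
        ctx.save_for_backward(a, b)
        return torch.where(a < b, a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        mask = a < b
        zeros = torch.zeros_like(grad_output)
        # Send the incoming gradient to whichever input was selected
        grad_a = torch.where(mask, grad_output, zeros)
        grad_b = torch.where(mask, zeros, grad_output)
        return grad_a, grad_b

torch.manual_seed(0)
a = torch.rand(5, requires_grad=True)
b = torch.rand(5, requires_grad=True)

c = WhereLess.apply(a, b)
c.sum().backward()
print(a.grad, b.grad)
```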

Or build the logic from operations whose backward is already well defined. For example:

# cast the boolean mask to float: subtracting from a bool tensor raises an error
b_larger = (a < b).float()
c = b_larger * a + (1 - b_larger) * b
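To check that the mask arithmetic keeps the graph, here is a quick sketch (random inputs, shape chosen arbitrarily) that compares its values and gradients against torch.where:

```python
import torch

torch.manual_seed(0)
a = torch.rand(8, requires_grad=True)
b = torch.rand(8, requires_grad=True)

# Mask-based version; the comparison itself carries no gradient
b_larger = (a < b).float()
c_mask = b_larger * a + (1 - b_larger) * b
c_mask.sum().backward()
grad_a_mask, grad_b_mask = a.grad.clone(), b.grad.clone()

# Reset gradients and repeat with torch.where
a.grad = None
b.grad = None
c_where = torch.where(a < b, a, b)
c_where.sum().backward()

print(torch.equal(c_mask, c_where))                                        # True
print(torch.equal(grad_a_mask, a.grad), torch.equal(grad_b_mask, b.grad))  # True True
```

One caveat with the mask version: if either branch contains inf or NaN, multiplying it by 0 yields NaN, whereas torch.where simply ignores the unselected value in the forward pass.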