J_Johnson
(J Johnson)
torch.where is the elementwise if/else equivalent for tensors.
Usage:
import torch
a = torch.rand((128, 3, 64, 64))
b = torch.rand((128, 3, 64, 64))
# elementwise: take a where a < b, otherwise b
c = torch.where(a<b, a, b)
https://pytorch.org/docs/stable/generated/torch.where.html
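As a quick sanity check, here is a minimal sketch (small made-up shapes, and the helper variable names are my own) suggesting that torch.where with this condition matches the elementwise minimum, and that gradients do reach a and b for whichever elements were selected:

```python
import torch

# Small illustrative tensors; the shapes are arbitrary.
a = torch.rand(4, 3, requires_grad=True)
b = torch.rand(4, 3, requires_grad=True)

# Elementwise: take a where a < b, otherwise b.
c = torch.where(a < b, a, b)
c.sum().backward()

# With this particular condition the result is the elementwise minimum.
same_as_min = torch.equal(c, torch.minimum(a, b))

# Each input receives gradient 1 where it was selected and 0 elsewhere.
mask = (a < b).detach()
grad_a_ok = torch.equal(a.grad, mask.float())
grad_b_ok = torch.equal(b.grad, (~mask).float())
```

Note that no gradient flows through the condition itself, since the comparison a < b produces a boolean tensor.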
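Note that torch.where already keeps the graph with respect to a and b: gradients flow to whichever element was selected, and the condition gets none. Alternatively, if you need different gradient behavior, you can either write custom backward functions, as described here: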
Or build your logic from operations that have a well-defined backward. For example:
b_larger = (a<b).to(a.dtype)  # cast the bool mask to a float tensor; 1 - bool_tensor raises an error
c = b_larger*a + (1-b_larger)*b
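To make the two routes above concrete, here is a rough sketch under my own assumptions. The class name SelectSmaller is hypothetical (it is not something PyTorch ships), and its backward simply routes the incoming gradient to the selected input, which is one reasonable choice rather than the only one. The second half repeats the mask arithmetic so the two can be compared:

```python
import torch


class SelectSmaller(torch.autograd.Function):
    """Hypothetical example: forward picks a where a < b, otherwise b."""

    @staticmethod
    def forward(ctx, a, b):
        mask = a < b
        ctx.save_for_backward(mask)
        return torch.where(mask, a, b)

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # Route the incoming gradient to whichever input was selected.
        return grad_out * mask, grad_out * (~mask)


a = torch.rand(4, 3, requires_grad=True)
b = torch.rand(4, 3, requires_grad=True)

# Route 1: custom backward function.
c_custom = SelectSmaller.apply(a, b)
grad_a_custom, grad_b_custom = torch.autograd.grad(c_custom.sum(), (a, b))

# Route 2: mask arithmetic built from ordinary differentiable ops.
mask = (a < b).to(a.dtype)
c_mask = mask * a + (1 - mask) * b
grad_a_mask, grad_b_mask = torch.autograd.grad(c_mask.sum(), (a, b))
```

For finite inputs like these, both routes give the same values and gradients. One caveat with the mask arithmetic: if the unselected branch contains inf or nan, multiplying it by 0 still yields nan, whereas torch.where just ignores that element in the forward pass.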