If statement with tensor scalar in PyTorch

Suppose I have two tensor-type scalars a and b. I want to compare them in an "if" statement, but when I use "if a < b" there is a detach step that takes around 70 ms on my machine. Is there an alternative way to achieve this? a < b is a bool-type tensor, and I don't know how to combine it with "if" without the detach. Thank you!

The if statement should not detach the tensors, but it will synchronize the code, since you have introduced data-dependent control flow. The host needs to synchronize with the device to be able to read the actual values in a and b before evaluating the if condition.
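A minimal sketch of what is happening (the device handling here is illustrative; assume a and b live on the GPU):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.rand((), device=device)  # 0-dim scalar tensor
b = torch.rand((), device=device)

# Branching on a tensor calls its __bool__, which copies the
# result of (a < b) back to the host. That copy is the sync
# you are measuring; .item() just makes it explicit.
if (a < b).item():
    print("a is smaller")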

There is a torch.where function that is the if equivalent for tensors.

Usage:


import torch

a = torch.rand((128, 3, 64, 64))
b = torch.rand((128, 3, 64, 64))

# elementwise select: take a where a < b, b elsewhere
c = torch.where(a < b, a, b)

https://pytorch.org/docs/stable/generated/torch.where.html

Alternatively, if you need to maintain the graph, you can either write custom backward functions, as described here:
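A minimal, hypothetical sketch of such a custom torch.autograd.Function (the name SelectWhereLess and the gradient routing are illustrative, not from the linked post):

import torch

class SelectWhereLess(torch.autograd.Function):
    # Returns x where a < b and y elsewhere, with an explicit backward.

    @staticmethod
    def forward(ctx, a, b, x, y):
        mask = a < b
        ctx.save_for_backward(mask)
        return torch.where(mask, x, y)

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # Route the incoming gradient to whichever input was selected.
        # a and b receive no gradient: the comparison is not differentiable.
        return None, None, grad_out * mask, grad_out * ~mask

Usage, for tensors x and y of matching shape: c = SelectWhereLess.apply(a, b, x, y)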

OR construct your logic using functions that have a well-defined backward. For example:

b_larger = (a < b).float()  # cast the bool mask: "1 - bool_tensor" raises an error
c = b_larger * a + (1 - b_larger) * b

Thank you! But what I need to do after (a < b) is a bit more troublesome. Suppose I have a 3x2x2 tensor d1 and a 5x2x2 tensor d2, and I want to realize "if (a < b): d1 = d2". If I insert a < b into the multiplication, there will be a dimension mismatch. I'm looking for a way to run an "if" on the GPU, or on a bool-type tensor variable. Thank you anyway!

If the dims don't match, then you'll first need to establish how you want to define what a < b means here. There are always ways to get the dims to match, but we need to know what you want that inequality to mean in this case.
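For instance (a hypothetical reading, assuming a and b are scalar tensors and you simply want to rebind d1 to d2 when the condition holds), the masking trick doesn't apply because the shapes differ, and a host-side branch is unavoidable:

import torch

a = torch.rand(())
b = torch.rand(())
d1 = torch.rand(3, 2, 2)
d2 = torch.rand(5, 2, 2)

# torch.where needs broadcastable shapes, so it can't pick between
# a 3x2x2 and a 5x2x2 tensor; rebinding the Python name requires a
# host-side branch, which synchronizes.
if (a < b).item():
    d1 = d2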