Suppose I have two scalar tensors a and b. I want to compare them in an "if" statement, but when I use "if a < b" there is a detach step taking around 70ms on my machine. Is there an alternative way to do this? a < b is a bool tensor and I don't know how to combine it with "if" without the detach. Thank you!
The if statement should not detach the tensors but synchronize the code, since you have introduced data-dependent control flow. The host needs to synchronize to be able to read the actual values of a and b before evaluating the if condition.
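A minimal sketch of that synchronization (names and shapes are just for illustration): `a < b` on 0-dim tensors yields a 0-dim bool tensor, and `if` calls its `__bool__`, which copies the value back to the host. On a GPU that copy has to wait for all pending kernels, which is the stall you are seeing, not a detach.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.tensor(1.0, device=device)
b = torch.tensor(2.0, device=device)

# `a < b` is a 0-dim bool tensor; evaluating it in `if` reads the value
# on the host, which forces a device synchronization when on the GPU.
if a < b:
    smaller = a
else:
    smaller = b
```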
There is a torch.where function that is the if equivalent for tensors.
Usage:
import torch

a = torch.rand((128, 3, 64, 64))
b = torch.rand((128, 3, 64, 64))
c = torch.where(a < b, a, b)  # elementwise: take a where a < b, else b
https://pytorch.org/docs/stable/generated/torch.where.html
Alternatively, if you need to maintain the graph, you can either write custom backward functions, as described here:
OR construct your logic from functions that have a well-defined backward. For example:
b_larger = (a < b).float()  # cast the bool mask so the arithmetic (and autograd) works
c = b_larger * a + (1 - b_larger) * b
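As a quick sanity check (a small sketch with arbitrary shapes), the masked-arithmetic form agrees elementwise with torch.where, since the mask is exactly 0.0 or 1.0:

```python
import torch

a = torch.rand(128, 3, 64, 64)
b = torch.rand(128, 3, 64, 64)

# Masked arithmetic: where the mask is 1 we keep a, where it is 0 we keep b.
b_larger = (a < b).float()
c1 = b_larger * a + (1 - b_larger) * b

# Same selection expressed with torch.where.
c2 = torch.where(a < b, a, b)

assert torch.equal(c1, c2)  # both pick the elementwise minimum of a and b
```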
Thank you! But what I want to do after (a < b) is a bit more involved. Suppose I have a 3x2x2 tensor d1 and a 5x2x2 tensor d2, and I want to express if (a < b): d1 = d2. If I insert a < b into the multiplication, there is a dimension mismatch. I'm looking for a way to express an "if" on the GPU, or on a bool tensor. Thank you anyway!
If the dims don't match, then you'll first need to establish what you want a < b to mean. There are always ways to get the dims to match, but we need to establish what that inequality should mean in this case.
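For instance, one possible (assumed) definition is that "a < b" holds only if it holds everywhere elementwise. That reduces the condition to a single 0-dim bool tensor, but selecting between tensors of different shapes is still data-dependent control flow, so the host must read the condition, and one synchronization is unavoidable (a sketch with made-up shapes):

```python
import torch

a = torch.rand(4)
b = torch.rand(4)
d1 = torch.rand(3, 2, 2)
d2 = torch.rand(5, 2, 2)

# One possible meaning of "a < b" for non-scalar tensors: true everywhere.
cond = (a < b).all()  # 0-dim bool tensor, stays on the same device as a and b

# d1 and d2 have different shapes, so torch.where cannot broadcast them;
# picking one or the other forces the host to read `cond` (one sync).
d = d2 if cond else d1
```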