Performance of if statements in PyTorch

I use this toy example to measure the performance of if statements in the forward pass.

At least on XLA devices (such as on Colab), when the condition is fixed (e.g. it depends only on model properties), performance doesn’t seem to be affected. However, when the condition depends heavily on the input, performance suffers.
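A minimal sketch of the two kinds of conditionals being compared. The original toy example isn't shown, so `ToyModel`, `use_relu` and the `x.sum() > 0` test are hypothetical stand-ins: the first `if` depends only on a Python attribute fixed at construction time, while the second depends on a tensor computed from the input.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):  # hypothetical model, not the original toy example
    def __init__(self, use_relu: bool = True):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Fixed condition: a plain Python bool set once, the same on every call.
        self.use_relu = use_relu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        # Fixed branch: decided by a model property, no tensor value is needed.
        if self.use_relu:
            x = torch.relu(x)
        # Data-dependent branch: the condition is a tensor computed from the
        # input, so Python needs its concrete value before it can pick a branch.
        if x.sum() > 0:
            x = x * 2
        else:
            x = x / 2
        return x


model = ToyModel(use_relu=True)
out = model(torch.randn(8, 4))
print(out.shape)  # torch.Size([8, 4])
```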

Are the conditions evaluated on the GPU, or is the code optimized for the GPU in some other way? Does the optimization and conversion of the code take all of this into account?

Is there a general rule for code-level optimization of if statements, or does it not really matter?



The main problem with conditionals is that they are evaluated on the Python side, so the values they depend on need to be on the CPU.
So if you use an accelerator like a GPU or TPU, the CPU has to wait for the value to be computed and copied back before it can evaluate the conditional and continue execution. This is sub-optimal but cannot really be worked around.
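To make the sync point concrete: `if x.sum() > 0:` calls `bool()` on a one-element tensor, which needs the value on the host, so on a GPU the Python thread blocks until that result is ready. The sketch below contrasts this with `torch.where`, which keeps the selection on the device. Note that this is a different computation rather than a way around the Python conditional: it evaluates both branches, so it only makes sense when both are cheap and produce the same shape, and it does not help when the control flow itself (e.g. how many layers to run) depends on the data. The function names are hypothetical.

```python
import torch


def branch_in_python(x: torch.Tensor) -> torch.Tensor:
    # `x.sum() > 0` is a 0-dim bool tensor; `if` calls bool() on it, which
    # copies the value to the CPU and waits for pending device work to finish.
    if x.sum() > 0:
        return x * 2
    return x / 2


def branch_on_device(x: torch.Tensor) -> torch.Tensor:
    # Both branches are computed; torch.where picks elementwise on the device,
    # so no value has to come back to Python to decide which branch to run.
    cond = x.sum() > 0  # 0-dim bool tensor, broadcast against x
    return torch.where(cond, x * 2, x / 2)


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 4, device=device)
assert torch.allclose(branch_in_python(x), branch_on_device(x))
```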

cc @ailzhang: are there other details about TPUs that I don’t know about?
