Checking termination condition is exceptionally slow

In my code I perform an iterative procedure (`for n in range(iterations): …`), and in every iteration I create the long tensor `unscaled_h_star_grad` on the GPU. I want to stop if its norm is small, so I wrote:

```python
Dfnulambda_grad = 1 / N * torch.norm(unscaled_h_star_grad) ** 2
cond3 = Dfnulambda_grad < 1e-6
if cond3:
    print(f"time derivative of the objective = {Dfnulambda_grad.item()} < 1e-6. Stopping iteration")
    break
```

However, this check is surprisingly slow: it contributes roughly a third of the total run time of my algorithm.
Apart from checking the condition only every, say, 25th iteration, is there a better way to handle this? I suspect there is some issue with GPU-CPU transfer or synchronization here.

Kind regards

PyTorch executes CUDA kernels asynchronously.
Your code uses a data-dependent condition, which synchronizes the host with the GPU: the check has to wait for all previously scheduled and launched kernels to finish, so their accumulated execution time shows up in this line.
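To illustrate, here is a minimal sketch (the tensor size and the `y = x * 2` kernel are illustrative, not from your code; it falls back to CPU if no GPU is available, where the effect does not appear):

```python
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(10_000_000, device=device)

t0 = time.perf_counter()
y = x * 2  # on CUDA this kernel is launched asynchronously; the call returns immediately
launch_time = time.perf_counter() - t0

t0 = time.perf_counter()
# .item() (like `if tensor:`) moves a value to the host and therefore forces a
# device synchronization: the CPU waits here for all queued kernels to finish.
small = (torch.norm(y) < 1e-6).item()
check_time = time.perf_counter() - t0

# On a GPU, check_time includes the run time of the kernels launched above,
# which is why the seemingly cheap comparison dominates a naive profile.
print(launch_time, check_time, small)
```

If you time the two sections separately with `torch.cuda.synchronize()` before each measurement, the cost moves back to the kernels where it actually belongs.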


OK, understood, thank you for your fast response. I guess that checking whether some loss is small enough every k-th iteration is quite a common thing to do. What is the runtime-optimal way to implement it?
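For reference, here is a runnable sketch of the every-k-iterations pattern I had in mind (the decaying random vector is a stand-in for my actual `unscaled_h_star_grad`, and `check_every = 25` is just an example interval):

```python
import torch

N = 1000
iterations = 500
check_every = 25  # pay the synchronization cost only once per 25 iterations

g = torch.randn(N)  # stand-in for the real GPU computation of the gradient
stopped_at = None
for n in range(iterations):
    g = 0.9 * g  # stays on the device; no host sync in the common case
    if n % check_every == check_every - 1:
        Dfnulambda_grad = torch.norm(g) ** 2 / N
        # .item() is the single host-device synchronization point per window.
        if Dfnulambda_grad.item() < 1e-6:
            stopped_at = n
            break

print(stopped_at)
```

The trade-off is that the loop may run up to `check_every - 1` extra iterations past the point where the condition first holds, which is usually cheap compared to synchronizing every step.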