In my code I run an iterative procedure (`for n in range(iterations): …`), and in every iteration I create a long tensor `unscaled_h_star_grad` on the GPU. I want to stop once its norm becomes small, so I wrote
```python
Dfnulambda_grad = 1 / N * torch.norm(unscaled_h_star_grad)**2
cond3 = Dfnulambda_grad < 1e-6
if cond3:
    print(f'time derivative of the objective = {Dfnulambda_grad.item()} < 1e-6. Stopping iteration')
    break
```
However, this check is expensive: it accounts for roughly a third of my algorithm's total run time. Apart from evaluating the condition only every 25th iteration or so, is there a better way to handle this? I suspect the comparison triggers a GPU-to-CPU transfer or a synchronization point.
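For context, here is a minimal sketch of the "check every 25th iteration" variant I am considering. The loop body, the random data, and the name `check_every` are placeholders for illustration, not my real computation:

```python
import torch

N = 1000            # placeholder tensor length
iterations = 500    # placeholder iteration count
check_every = 25    # only force a CPU sync every 25th step

device = "cuda" if torch.cuda.is_available() else "cpu"

for n in range(iterations):
    # stands in for the real per-iteration work that produces the gradient
    unscaled_h_star_grad = torch.randn(N, device=device)

    if n % check_every == 0:
        # the squared norm stays on the GPU; the host only blocks here,
        # when .item() pulls the scalar back for the comparison
        Dfnulambda_grad = torch.sum(unscaled_h_star_grad**2) / N
        if Dfnulambda_grad.item() < 1e-6:
            print(f'time derivative of the objective = '
                  f'{Dfnulambda_grad.item()} < 1e-6. Stopping iteration')
            break
```

This amortizes the synchronization cost over `check_every` iterations, at the price of running up to `check_every - 1` extra iterations past convergence.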
Kind regards