I have encountered some strange behaviour of torch.max(). I found that the following code is the most time-consuming part of my program: it takes 0.042899 s to execute the torch.max call, where grad_norm is a CUDA float tensor of shape (1024,).
start = time.time()
temp = torch.max(grad_norm)
end = time.time()
print('NN computation time %f' % (end - start))
Then I slightly modified the code to
temp, _ = torch.max(grad_norm, dim=0)
This version was much faster (0.000020 s, roughly 2000x faster). Later I found that
temp = temp.item()
is nearly as slow as the original torch.max() call. Is this normal? Is there an explanation for this behaviour?
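For anyone who wants to reproduce this, here is a minimal self-contained sketch of the three measurements. I'm filling grad_norm with random values here as a placeholder, since in my real code it comes from gradient computations, and I'm assuming a CUDA device is available:

import time
import torch

# Placeholder for the real gradient-norm tensor; assumes CUDA is available.
grad_norm = torch.rand(1024, device='cuda')

# Case 1: full reduction without a dim argument.
start = time.time()
temp = torch.max(grad_norm)
end = time.time()
print('torch.max(grad_norm): %f s' % (end - start))

# Case 2: reduction with dim=0, which also returns indices.
start = time.time()
temp, _ = torch.max(grad_norm, dim=0)
end = time.time()
print('torch.max(grad_norm, dim=0): %f s' % (end - start))

# Case 3: pulling the result back to a Python float.
start = time.time()
val = temp.item()
end = time.time()
print('temp.item(): %f s' % (end - start))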