Speed of torch.max

Hi everyone,

I have encountered a weird behaviour of torch.max(). I found out that the following code is the most time-consuming part of my program. It takes 0.042899 s to execute the torch.max() call, where grad_norm is a CUDA float tensor with shape (1024,).

start = time.time()
temp = torch.max(grad_norm)
end = time.time()
print('NN computation time %f' % (end - start))

Then I slightly modified the code to

temp, _ = torch.max(grad_norm, dim=0)

It became much faster (0.000020 s, roughly 2000× faster). Later I found out that

temp = temp.item()

is nearly as slow as computing torch.max() itself. Is this normal? Is there an explanation for it?

The answer to your question is the same as the reason someone else thought .item() was the slowest operation: Tensor.item() takes a lot of running time
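
In short, CUDA kernels are launched asynchronously, so a host-side timer around a single op mostly measures the kernel launch, and whichever call happens to synchronize (such as .item()) gets billed for all the queued work. Here is a minimal timing sketch, assuming a CUDA device is available and using a random stand-in for your grad_norm:

import time
import torch

grad_norm = torch.randn(1024, device='cuda')  # stand-in for your tensor

torch.cuda.synchronize()       # drain any previously queued kernels
start = time.time()
temp = torch.max(grad_norm)    # queues the kernel and returns immediately
torch.cuda.synchronize()       # wait for the kernel to actually finish
end = time.time()
print('max kernel time %f' % (end - start))

start = time.time()
value = temp.item()            # device-to-host copy: forces a sync
end = time.time()
print('.item() time %f' % (end - start))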


Thanks for the clarification! It seems it's a synchronization thing that results in inaccurate timing of the .item() operation.
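
For what it's worth, torch.cuda.Event can time the kernel on the device itself, which sidesteps the host-side timing pitfall entirely. A sketch under the same assumptions (CUDA device available, random stand-in for grad_norm):

import torch

grad_norm = torch.randn(1024, device='cuda')

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

start_evt.record()
temp, _ = torch.max(grad_norm, dim=0)
end_evt.record()

torch.cuda.synchronize()  # ensure both events have actually been recorded
print('max kernel time %f ms' % start_evt.elapsed_time(end_evt))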

Since torch.max() over the whole tensor returns a Python number, does that mean it will call the equivalent of .item() at the end, and that is the reason why torch.max() seems “slow”?

Yes, your understanding is correct.
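
As a side note: in recent PyTorch versions the full reduction returns a 0-dim tensor that stays on the GPU, and the blocking device-to-host copy only happens when you convert it to a Python number. A quick sketch (again assuming a CUDA device):

import torch

grad_norm = torch.randn(1024, device='cuda')

temp = torch.max(grad_norm)     # 0-dim CUDA tensor, no sync yet
print(temp.shape, temp.device)  # torch.Size([]) cuda:0

value = temp.item()             # here the host blocks until the GPU finishes
print(type(value))              # <class 'float'>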
