Hi everyone,

I have encountered some weird behaviour of torch.max(). I found that the following snippet is the most time-consuming part of my code: it takes 0.042899 s to execute the torch.max call, where grad_norm is a CUDA float tensor with shape (1024,).

```
start = time.time()
temp = torch.max(grad_norm)
end = time.time()
print('NN computation time %f' % (end - start))
```

Then I slightly modified the code to

```
temp, _ = torch.max(grad_norm, dim=0)
```

It became much faster (0.000020 s, roughly 2000× faster). Later I found that

```
temp = temp.item()
```

is nearly as slow as the original torch.max() call. Is this normal? Is there an explanation for this behaviour?
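For reference, here is the full timing setup I am describing, as a self-contained sketch (the grad_norm here is a stand-in random tensor, not my actual gradients). I have also added torch.cuda.synchronize() calls, since I suspect asynchronous kernel launches might be skewing my wall-clock measurements, but I am not sure that is the whole story:

```
import time
import torch

# Stand-in for my actual gradient-norm tensor: 1024 floats, on GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"
grad_norm = torch.rand(1024, device=device)

if device == "cuda":
    torch.cuda.synchronize()  # drain any previously queued kernels before timing
start = time.time()
temp = torch.max(grad_norm)   # returns a 0-dim tensor; the value stays on the GPU
if device == "cuda":
    torch.cuda.synchronize()  # wait for the max kernel itself to finish
end = time.time()
print('NN computation time %f' % (end - start))

# .item() copies the scalar back to the CPU, which forces a device sync.
value = temp.item()
```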