Why does torch.max use so much GPU memory?

I tried to use torch.max on a [1000, 1000, 3, 256] tensor along dim=2, but got a CUDA out-of-memory error. Why does torch.max use so much GPU memory?

Code:

import torch
a = torch.rand(1000, 1000, 3, 256).cuda()
b = a.max(dim=2)[0]

And the error message:

RuntimeError: CUDA out of memory. Tried to allocate 15.26 GiB (GPU 0; 23.70 GiB total capacity; 5.72 GiB already allocated; 13.82 GiB free; 5.72 GiB reserved in total by PyTorch)

The tensor a itself uses only about 2.86 GiB of memory, yet torch.max tried to allocate 15.26 GiB.
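For reference, the input size can be checked by hand. A quick sketch (the 4-byte figure assumes the default float32 dtype):

```python
# Memory footprint of a float32 tensor of shape (1000, 1000, 3, 256).
shape = (1000, 1000, 3, 256)
numel = 1
for d in shape:
    numel *= d                      # 768,000,000 elements
bytes_per_elem = 4                  # float32
size_gib = numel * bytes_per_elem / 2**30
print(f"{size_gib:.2f} GiB")        # roughly 2.86 GiB
```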

Very interesting. I can reproduce this on a 3080 with a recent build. I'll see if I can find out more about the underlying cause.

We’ve opened an issue for this:
min/max require a huge allocation on GPU · Issue #63869 · pytorch/pytorch (github.com)

For now, if you don't need the indices, a workaround is to use amax instead of max here:

import torch
a = torch.rand(1000, 1000, 3, 256).cuda()
b = a.amax(dim=2)

Yes, amax uses less memory, but it is much slower during training.
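One reason amax needs less output memory: max(dim=...) returns both the values and the int64 indices of the maxima, while amax returns only the values. A rough sketch of the two output sizes for this reduction (the int64-index assumption matches PyTorch's LongTensor indices):

```python
# Output sizes when reducing a (1000, 1000, 3, 256) float32 tensor over dim=2.
out_numel = 1000 * 1000 * 256        # result shape: (1000, 1000, 256)
values_gib = out_numel * 4 / 2**30   # float32 values, returned by both max and amax
indices_gib = out_numel * 8 / 2**30  # int64 indices, returned only by max
print(f"values: {values_gib:.2f} GiB, indices: {indices_gib:.2f} GiB")
```

Note that neither output comes close to the 15.26 GiB that was requested, which is why the linked issue treats the allocation as a bug rather than expected overhead.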