I tried to use torch.max on a [1000, 1000, 3, 256] tensor along dim=2, but got a CUDA out-of-memory error. Why does torch.max need so much GPU memory?
```python
import torch

a = torch.rand(1000, 1000, 3, 256).cuda()
b = a.max(dim=2)
```
And the error message:
RuntimeError: CUDA out of memory. Tried to allocate 15.26 GiB (GPU 0; 23.70 GiB total capacity; 5.72 GiB already allocated; 13.82 GiB free; 5.72 GiB reserved in total by PyTorch)
a itself only takes about 2.86 GiB (1000 × 1000 × 3 × 256 float32 elements × 4 bytes each), yet torch.max tried to allocate another 15.26 GiB.
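For scale, here is the back-of-the-envelope arithmetic behind those figures (a minimal sketch; it assumes float32 values and the int64 indices that `max(dim=...)` returns alongside them):

```python
# Input tensor: [1000, 1000, 3, 256] float32
numel_in = 1000 * 1000 * 3 * 256
input_gib = numel_in * 4 / 2**30      # 4 bytes per float32 -> ~2.86 GiB

# Outputs of a.max(dim=2): values and indices, both shaped [1000, 1000, 256]
numel_out = 1000 * 1000 * 256
values_gib = numel_out * 4 / 2**30    # float32 values  -> ~0.95 GiB
indices_gib = numel_out * 8 / 2**30   # int64 indices   -> ~1.91 GiB

print(round(input_gib, 2), round(values_gib, 2), round(indices_gib, 2))
```

So the input plus both outputs should fit in roughly 6 GiB, nowhere near the 15.26 GiB the error message reports.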