Consider the following code, which throws an `OutOfMemoryError`:
```python
import torch

a = torch.randn(1000000, 1, device="cuda")
mask = torch.randint(0, 2, (1, 1000000), device="cuda", dtype=torch.bool)
a_, mask_ = torch.broadcast_tensors(a, mask)
b = torch.masked.sum(input=a_, mask=mask_, dim=1)
# b = torch.masked.masked_tensor(data=a_, mask=mask_).sum(1)
# b = torch.masked.as_masked_tensor(data=a_, mask=mask_).sum(1)
```
This throws the following error:

```
File [...], line 6, in <module>
    b = torch.masked.sum(input=a_, mask=mask_, dim=1)
[...]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3725.29 GiB
[...]
```
I expect the memory usage to be:

- `a`: 4,000,000 B → 4 MB
- `mask`: 1 MB
- The (manually) broadcast tensors take up no additional memory
- `b` is of the same shape as `a`, so it takes 4 MB too
- If cloned, `a_` would take up 4 * 1000000 ** 2 B → 3725.29 GiB
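The "no additional memory" point can be checked directly: `torch.broadcast_tensors` returns views with stride 0 along the broadcast dimension, so they share the original storage. A minimal sketch of that check (on CPU with a scaled-down size of my choosing, so it runs anywhere):

```python
import torch

# Smaller CPU stand-ins for the tensors above (n = 1000 instead of 1000000).
n = 1000
a = torch.randn(n, 1)
mask = torch.randint(0, 2, (1, n), dtype=torch.bool)

a_, mask_ = torch.broadcast_tensors(a, mask)
print(a_.shape)                       # torch.Size([1000, 1000])
print(a_.stride())                    # (1, 0) -- stride 0 along the broadcast dim
print(a_.data_ptr() == a.data_ptr())  # True: the view shares a's storage
print(a_.untyped_storage().nbytes())  # 4000 bytes, same as a itself (n * 4)
```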
It seems that `torch.masked.sum` internally clones `a_`, leading to that catastrophic VRAM request.
As you may see, I've formulated similar expressions with masked tensors (the commented-out lines), but their construction currently seems to require a clone that fully materializes the broadcast tensor…
Is there any way to avoid a memory explosion for such a “masked sum” operation?
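For what it's worth, in this particular example the broadcast structure makes every row of `a_` a constant `a[i, 0]`, so the masked row-sum collapses to `a[i, 0]` times the number of `True` entries in `mask`. A workaround sketch along those lines (this exploits the specific shapes here and is not a general answer to the question; CPU and a small `n` so it runs anywhere):

```python
import torch

n = 1000
a = torch.randn(n, 1)
mask = torch.randint(0, 2, (1, n), dtype=torch.bool)

# b[i] = a[i, 0] * (count of True in mask); never materializes an (n, n) tensor.
b = a.squeeze(1) * mask.sum()

# Cross-check against the dense computation, affordable at this small size:
a_, mask_ = torch.broadcast_tensors(a, mask)
reference = (a_ * mask_).sum(dim=1)
print(torch.allclose(b, reference))  # True
```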