Unreasonable memory consumption by torch.masked.sum with broadcasted tensors

Consider the following code, which throws an OutOfMemoryError:

import torch

a = torch.randn(1000000, 1, device="cuda")                                  # (N, 1) float32, N = 1,000,000
mask = torch.randint(0, 2, (1, 1000000), device="cuda", dtype=torch.bool)   # (1, N) bool
a_, mask_ = torch.broadcast_tensors(a, mask)  # (N, N) views, no extra memory
b = torch.masked.sum(input=a_, mask=mask_, dim=1)  # requests ~3725 GiB and OOMs
# b = torch.masked.masked_tensor(data=a_, mask=mask_).sum(1)     # same problem
# b = torch.masked.as_masked_tensor(data=a_, mask=mask_).sum(1)  # same problem

This throws the following error:

File [...], line 6, in <module>
    b = torch.masked.sum(input=a_, mask=mask_, dim=1)
[...]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3725.29 GiB
[...]

I expect the memory usage to be:

  1. a: 1,000,000 × 4B → 4MB
  2. mask: 1,000,000 × 1B → 1MB
  3. The (manually) broadcasted tensors a_ and mask_ are views and take up no additional memory (see the check after this list)
  4. b has 1,000,000 elements, just like a, so it takes 4MB too
  5. Only a fully materialized copy of a_ would take 4 × 1,000,000² B → 3725.29 GiB
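To confirm point 3, here is a minimal check (assuming a CUDA device is available; torch.cuda.memory_allocated reports the caching allocator's current usage) showing that the broadcast views have zero strides along the expanded dimensions and allocate nothing:

import torch

a = torch.randn(1000000, 1, device="cuda")
mask = torch.randint(0, 2, (1, 1000000), device="cuda", dtype=torch.bool)

before = torch.cuda.memory_allocated()
a_, mask_ = torch.broadcast_tensors(a, mask)  # views over the same storage
after = torch.cuda.memory_allocated()

print(a_.shape, a_.stride())        # torch.Size([1000000, 1000000]) (1, 0)
print(mask_.shape, mask_.stride())  # torch.Size([1000000, 1000000]) (0, 1)
print(after - before)               # 0 -- broadcasting allocated no new memory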

It seems that torch.masked.sum internally clones a_, which materializes the full 1,000,000 × 1,000,000 tensor and leads to that catastrophic VRAM request.
As you can see from the commented-out lines above, I've tried equivalent expressions with masked tensors, but their construction also seems to require a clone that fully expands the memory usage.
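For this particular broadcast pattern I can sidestep the expansion by hand, since a_ is constant along dim 1 and mask_ is constant along dim 0, so the masked sum factorizes (a minimal sketch; b_workaround is just an illustrative name):

# b[i] = sum over j with mask_[i, j] True of a_[i, j]
#      = a[i, 0] * (number of True entries in mask)
b_workaround = a.squeeze(1) * mask.sum()  # shape (1000000,), ~4 MB

But that only works because of this specific broadcast structure.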

Is there a general way to avoid the memory explosion for such a “masked sum” operation?