How can I do a masked scatter operation without a memory allocation?

Hi all,

Basically, I would like to do this in CUDA in eager mode, without performing a memory allocation:

dest[mask] = src[mask]

Unfortunately, this performs a memory allocation, and the allocation is sized by a synchronizing device-to-host copy. You can confirm the synchronization as follows:

import torch

# Turn any synchronizing CUDA call into an error
torch.cuda.set_sync_debug_mode("error")

dest = torch.tensor([-1, -1, -1, -1, -1], device="cuda")
src  = torch.tensor([0, 1, 2, 3, 4], device="cuda")
mask = torch.tensor([True, False, False, False, True], device="cuda")

dest[mask] = src[mask]

This raises RuntimeError: called a synchronizing CUDA operation.

I understand why this allocates, at least in eager mode. Since execution is imperative, there is no way to know that the output of src[mask] can be written directly into the corresponding elements of dest, even though both are indexed by the same mask. Therefore an intermediate allocation must happen, and its size depends on the number of True values in mask, which is only known at runtime.
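To make the data dependence concrete, here is a small CPU sketch (same behavior as CUDA, minus the synchronization): the size of the result of a boolean-mask index equals the number of True entries, so it cannot be known without reading the mask.

```python
import torch

src  = torch.tensor([0, 1, 2, 3, 4])
mask = torch.tensor([True, False, False, False, True])

# Allocates a fresh tensor whose length equals mask.sum();
# on CUDA, obtaining that count forces a device-to-host sync.
selected = src[mask]
print(selected)        # tensor([0, 4])
print(selected.shape)  # torch.Size([2])
```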

I originally thought that I should use torch.Tensor.masked_scatter_ — PyTorch 2.1 documentation, but, without going into details, it does not perform a “scatter” in the sense most people expect: it consumes the source tensor contiguously (its first mask.sum() elements, in order) rather than selecting src[mask]. This Stack Overflow question explains it in more detail: pytorch - torch.masked_scatter result did not meet expectations - Stack Overflow

The motivation is that I want to use CUDA graphs to speed up inference for a particular model that launches several short-running CUDA kernels (so much so that 90% of the time is spent on the CPU). However, without a way to do a “masked select” that avoids a synchronizing memory allocation, I seem to be out of luck.

torch.index_select and friends don’t work either, because I would first need to build the index tensor from the mask, which itself requires an allocation (since I need to count the number of True entries in the mask).


The problem is indeed that the shape of the intermediate result depends on the tensor’s contents.
You can avoid that by using only tensor-wide ops (which, of course, will be less efficient for very sparse masks):

res = dest - dest * mask + src * mask

You can make this in-place to reduce extra allocations as needed. Either way, it won’t trip sync debug mode or CUDA graph capture, since every intermediate has a fixed shape.
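For illustration, here is the same fixed-size arithmetic run on CPU; every intermediate has the same shape as dest, so nothing is data-dependent:

```python
import torch

dest = torch.tensor([-1, -1, -1, -1, -1])
src  = torch.tensor([0, 1, 2, 3, 4])
mask = torch.tensor([True, False, False, False, True])

# Where mask is True: dest - dest + src = src; where False: dest - 0 + 0 = dest.
res = dest - dest * mask + src * mask
print(res)  # tensor([ 0, -1, -1, -1,  4])
```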

Actually, someone told me offline that torch.where(mask, src, dest, out=dest) is exactly what I need, so that is what I went with.
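For completeness, a CPU sketch of that solution: torch.where picks src where mask is True and dest elsewhere, and writing the result back into dest via out= keeps every tensor at a fixed size, so it is safe to capture in a CUDA graph.

```python
import torch

dest = torch.tensor([-1, -1, -1, -1, -1])
src  = torch.tensor([0, 1, 2, 3, 4])
mask = torch.tensor([True, False, False, False, True])

# Elementwise select with no data-dependent allocation;
# out= must match the promoted dtype of src and dest.
torch.where(mask, src, dest, out=dest)
print(dest)  # tensor([ 0, -1, -1, -1,  4])
```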