`masked_scatter` slower than `where`?

If I want to partially replace a tensor x with y based on a condition (a Boolean tensor), I think the most suitable method is masked_scatter_, since it is designed for exactly this usage. However, I found that torch.where is even faster.

import torch
import cProfile

n=100000
dim=2000
device='cuda'

# masked_scatter_
x = torch.rand(n, dim, device=device)
y = torch.rand(n, dim, device=device)
condition = torch.rand(n, dim, device=device) > 0.5
cProfile.run('x.masked_scatter_(condition, y)')


# torch.where
x = torch.rand(n, dim, device=device)
y = torch.rand(n, dim, device=device)
condition = torch.rand(n, dim, device=device) > 0.5
cProfile.run('x = torch.where(condition, x, y)')

What I got is

 4 function calls in 0.112 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.112    0.112 <string>:1(<module>)
        1    0.000    0.000    0.112    0.112 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.112    0.112    0.112    0.112 {method 'masked_scatter_' of 'torch._C._TensorBase' objects}


 4 function calls in 0.000 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method where}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

Using the CPU gives a similar result. Besides, torch.where has lower memory usage (not shown in the results above) even though I am using the in-place masked_scatter_. Why does this happen?

Note that CUDA operations are executed asynchronously, so you would have to synchronize the code before starting and stopping timers.
I'm not deeply familiar with cProfile, but I assume it doesn't synchronize the GPU internally.
The torch.utils.benchmark utilities might be useful to compare the different methods, as they add warmup iterations and synchronize internally.
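For comparison, the synchronization can also be done by hand around a plain timer. A minimal sketch (sizes reduced here to keep it light; the timed helper is just for illustration):

```python
import time
import torch

# Sizes reduced from the original example just to keep this sketch light.
n, dim = 1024, 1024
device = 'cuda' if torch.cuda.is_available() else 'cpu'

x = torch.rand(n, dim, device=device)
y = torch.rand(n, dim, device=device)
condition = torch.rand(n, dim, device=device) > 0.5

def timed(fn, iters=10):
    # Flush pending kernels so earlier work is not attributed to fn.
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == 'cuda':
        torch.cuda.synchronize()  # wait for the queued kernels to finish
    return (time.perf_counter() - start) / iters

print('masked_scatter_:', timed(lambda: x.masked_scatter_(condition, y)))
print('torch.where:    ', timed(lambda: torch.where(condition, x, y)))
```

Without the second synchronize call, the timer would stop while the kernels are still queued, which is exactly why the cProfile numbers above are misleading on CUDA.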

Thanks for your reply. I tried torch.utils.benchmark, but I got similar results:

import torch
n=100000
dim=2000
device='cuda'
import torch.utils.benchmark as benchmark

# masked_scatter_
x = torch.rand(n, dim, device=device)
y = torch.rand(n, dim, device=device)
condition = torch.rand(n,1, device=device) > 0.5
print(benchmark.Timer('x.masked_scatter_(condition.expand(n,dim), y)',globals={'x':x,'y':y,'condition':condition,'n':n,'dim':dim}).timeit(10))

# torch.where
x = torch.rand(n, dim, device=device)
y = torch.rand(n, dim, device=device)
condition = torch.rand(n,1, device=device) > 0.5
print(benchmark.Timer('torch.where(condition.expand(n,dim), x, y)',globals={'x':x,'y':y,'condition':condition,'n':n,'dim':dim}).timeit(10))

Output:

<torch.utils.benchmark.utils.common.Measurement object at 0x0000019B0E340670>
x.masked_scatter_(condition.expand(n,dim), y)
  22.45 ms
  1 measurement, 10 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x0000019B097ADCA0>
torch.where(condition.expand(n,dim), x, y)
  5.53 ms
  1 measurement, 10 runs , 1 thread

Hello,

I have no answer but I made a similar observation today.
I compared tensor indexing vs masked_scatter:

x[condition] = y[condition]
# VS
x.masked_scatter_(condition.view(-1, 1), y)

and indeed masked_scatter is much slower.

Additionally, I found the tensor-indexing solution to be about 20% faster than torch.where, and more readable.
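As a side note, the two forms are not always interchangeable: masked_scatter_ consumes the source tensor's elements in flattened order, while indexing copies y's values at the masked positions themselves. A small CPU sketch of the difference:

```python
import torch

x = torch.zeros(2, 3)
y = torch.tensor([[10., 11., 12.], [13., 14., 15.]])
mask = torch.tensor([[False, True, False], [True, False, True]])

a = x.clone()
a.masked_scatter_(mask, y)  # True positions get y's elements in flattened order
print(a)  # tensor([[ 0., 10.,  0.], [11.,  0., 12.]])

b = x.clone()
b[mask] = y[mask]  # True positions get y's values at those same positions
print(b)  # tensor([[ 0., 11.,  0.], [13.,  0., 15.]])
```

So the timing comparison is fair, but whether the results match depends on the mask pattern.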


On my machine, I get:

<torch.utils.benchmark.utils.common.Measurement object at 0x7f7845fc4b70>
x.masked_scatter_(condition.view(-1,1), y)
  28.40 ms
  1 measurement, 100 runs , 1 thread

<torch.utils.benchmark.utils.common.Measurement object at 0x7f7845fc4ac8>
torch.where(condition.view(-1,1), x, y)
  6.19 ms
  1 measurement, 100 runs , 1 thread

<torch.utils.benchmark.utils.common.Measurement object at 0x7f7845fc4828>
x[condition] = y[condition]
  4.79 ms
  1 measurement, 100 runs , 1 thread

Yes, basically the same as what I got. Comparing where and x[condition] = y[condition], the advantage of where is that it has lower memory usage. I do not know how to benchmark memory usage properly, but I observed it in my task manager.
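If it helps, peak GPU memory can be queried from PyTorch itself instead of the task manager, via the CUDA caching-allocator statistics. A minimal sketch (the peak_cuda_memory helper here is just for illustration):

```python
import torch

def peak_cuda_memory(fn):
    """Run fn once and return the peak CUDA memory allocated in MiB,
    or None when no GPU is available."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # clear the running peak counter
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

# e.g. peak_cuda_memory(lambda: torch.where(condition, x, y))
```

Note this only tracks memory allocated through PyTorch's allocator, so it can differ from what the task manager reports for the whole process.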