How to multiply multiple masks with an RGB image?

I have a mask tensor of shape BxCxHxW, where C is the number of masks (C=10). I also have an input tensor of shape Bx3xHxW.
I would like to do the following without a loop, so that it runs faster.

How can I multiply my input by the masks so that the output has shape (B*C)x3xHxW,
where output[:C] holds input[0] multiplied by each of the C masks in mask[0], output[C:2*C] holds input[1] multiplied by each of the masks in mask[1], and so on?

In loop form, it would look like this:

output = torch.zeros(B*C, 3, H, W)
for b in range(B):
    for i in range(C):
        output[i + b*C] = torch.mul(input[b].unsqueeze(0),
                                    mask[b, i].unsqueeze(0).unsqueeze(0))

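I think this is a natural place to use einsum, since it is effectively a variation on an outer-product-style computation: the batch and spatial indices are shared, while the RGB channel index and the mask index are both kept in the output: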

import torch
import time

B = 16
C = 64
H = 32
W = 32

input = torch.randn(B, 3, H, W)
mask = torch.randn(B, C, H, W)

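output = torch.zeros(B*C, 3, H, W)

# Loop version from the question
t1 = time.time()
for b in range(B):
    for i in range(C):
        # (1, 3, H, W) * (1, 1, H, W) broadcasts to (1, 3, H, W)
        output[i + b*C] = torch.mul(input[b].unsqueeze(0),
                                    mask[b, i].unsqueeze(0).unsqueeze(0))
t2 = time.time()

# einsum version: b, h, w are shared; c (mask) and l (RGB channel) are both kept
output2 = torch.einsum('blhw,bchw->bclhw', input, mask)
# Merge (B, C) into one leading dimension, so the row index is c + b*C
output2 = output2.reshape(B*C, 3, H, W)
t3 = time.time()

print(t2-t1, t3-t2)  # loop time, einsum time (seconds)
print(torch.allclose(output, output2))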
0.0157625675201416 0.0012059211730957031
True
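
For what it's worth, the same result can probably also be obtained with plain broadcasting, in case the einsum string is hard to read. This is only a sketch: the small sizes and the names via_einsum, via_broadcast, same, and spot are made up for illustration, and I have not timed it against the einsum version.

```python
import torch

# Small, arbitrary sizes chosen only for this illustration
B, C, H, W = 2, 4, 8, 8
torch.manual_seed(0)
input = torch.randn(B, 3, H, W)
mask = torch.randn(B, C, H, W)

# einsum version from the answer above
via_einsum = torch.einsum('blhw,bchw->bclhw', input, mask).reshape(B * C, 3, H, W)

# Broadcasting version: (B, 1, 3, H, W) * (B, C, 1, H, W) -> (B, C, 3, H, W)
via_broadcast = (input.unsqueeze(1) * mask.unsqueeze(2)).reshape(B * C, 3, H, W)

# Compare the two versions, then spot-check one entry against the loop formula
same = torch.allclose(via_einsum, via_broadcast)
b, i = 1, 2
spot = torch.allclose(via_broadcast[i + b * C], input[b] * mask[b, i])
print(same, spot)
```

The unsqueeze calls insert a size-1 axis where each tensor is missing the other's index, so the elementwise product expands to (B, C, 3, H, W) before the reshape, which is the same layout the einsum produces.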