Mask rows in a matrix

Hi,

I have a mask vector of binary values, and I would like to use it to mask rows of a matrix:

mask = [1, 0, 1]
matrix = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]

out = mask * matrix
out =  [[1, 2, 3],
        [0, 0, 0],
        [7, 8, 9]]
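In PyTorch, even the single-matrix case needs one detail. A mask of shape [n] multiplied by an [n, m] matrix broadcasts along the last dimension, so it zeroes columns rather than rows. A minimal sketch, assuming mask and matrix are torch tensors, adds a trailing dimension so that each mask value scales a whole row:

```python
import torch

mask = torch.tensor([1, 0, 1])
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# mask.unsqueeze(-1) has shape [3, 1], so each mask value broadcasts across one row
out = mask.unsqueeze(-1) * matrix
print(out)
# tensor([[1, 2, 3],
#         [0, 0, 0],
#         [7, 8, 9]])

# Without the unsqueeze, the [3] mask lines up with the columns instead
cols = mask * matrix
print(cols)
# tensor([[1, 0, 3],
#         [4, 0, 6],
#         [7, 0, 9]])
```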

I am trying to figure out how to implement the out = mask * matrix operation when I have a batch of different masks and a single matrix, so the shapes are:

mask.shape  = [batch, n]
matrix.shape = [n, m]
out.shape = [batch, n, m]
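Each output entry is simply out[b, i, j] = mask[b, i] * matrix[i, j], with no summation, so the shape mapping above can also be written as an einsum. This is a sketch; the sizes below are arbitrary placeholders:

```python
import torch

batch, n, m = 2, 3, 4
mask = torch.randint(0, 2, (batch, n)).float()  # [batch, n], binary values
matrix = torch.randn(n, m)                      # [n, m]

# Every index appears in the output, so nothing is summed:
# out[b, i, j] = mask[b, i] * matrix[i, j]
out = torch.einsum("bi,ij->bij", mask, matrix)
print(out.shape)  # torch.Size([2, 3, 4])
```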

I think I might have to use Tensor.repeat to repeat the matrix over the batch dimension, but I'm not sure.
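On the repeat idea: Tensor.repeat physically copies the matrix once per batch element. Tensor.expand produces the same [batch, n, m] shape as a view that reuses the original storage, which is enough here because the matrix is only read. A minimal sketch with arbitrary sizes:

```python
import torch

batch, n, m = 2, 3, 4
mask = torch.randint(0, 2, (batch, n)).float()
matrix = torch.randn(n, m)

repeated = matrix.repeat(batch, 1, 1)    # real copy, shape [batch, n, m]
expanded = matrix.expand(batch, -1, -1)  # view of the same storage, shape [batch, n, m]

out_repeat = mask.unsqueeze(-1) * repeated
out_expand = mask.unsqueeze(-1) * expanded
print(torch.equal(out_repeat, out_expand))       # True
print(expanded.data_ptr() == matrix.data_ptr())  # True: no new memory for the view
```

The expanded view has stride 0 along the batch dimension, so it should not be written to in place.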

I actually figured this out straight away. All you need to do is unsqueeze the mask to shape [batch, n, 1] and repeat the matrix over the batch dimension:

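import torch

batch_size = 2
a = torch.randint(0, 2, (batch_size, 3)).float()  # batch of binary masks, [batch, n]
b = torch.randn(3, 4)                              # single matrix, [n, m]

a = a.unsqueeze(-1)             # [batch, n, 1]: one mask value per row

b = b.unsqueeze(0)              # [1, n, m]
b = b.repeat(batch_size, 1, 1)  # [batch, n, m]

print(a.shape, b.shape)
# torch.Size([2, 3, 1]) torch.Size([2, 3, 4])

c = a * b                       # each mask value scales a full row of m entries
print(c.shape)
# torch.Size([2, 3, 4])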
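The repeat step is optional: broadcasting already expands a [batch, n, 1] mask against an [n, m] matrix to [batch, n, m]. Below is a sketch of that, plus a torch.where variant that may be preferable if the mask is boolean or the matrix can contain inf or NaN (inf * 0 gives NaN, whereas torch.where simply selects zero). The variable names mirror the example above; the sizes are arbitrary:

```python
import torch

batch_size = 2
a = torch.randint(0, 2, (batch_size, 3)).float()  # batch of binary masks, [batch, n]
b = torch.randn(3, 4)                              # single matrix, [n, m]

# Broadcasting: [batch, n, 1] * [n, m] -> [batch, n, m], no repeat needed
c = a.unsqueeze(-1) * b
print(c.shape)  # torch.Size([2, 3, 4])

# Same result via repeat, for comparison
c_repeat = a.unsqueeze(-1) * b.unsqueeze(0).repeat(batch_size, 1, 1)
print(torch.equal(c, c_repeat))  # True

# Selection instead of multiplication: masked rows become exactly 0,
# even if b contains inf or NaN
keep = a.bool().unsqueeze(-1)                      # [batch, n, 1]
c_where = torch.where(keep, b, torch.zeros_like(b))
print(torch.allclose(c, c_where))  # True for finite b
```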