# Mask rows in a matrix

Hi,

I have a mask vector of binary values and would like to use it to mask rows of a matrix, zeroing every row whose mask value is 0:

```python
mask = [1, 0, 1]
matrix = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]

out = [[1, 2, 3],
       [0, 0, 0],
       [7, 8, 9]]
```
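For a single mask, a minimal PyTorch sketch of this operation, using the example values above converted to tensors, is to give the mask a trailing dimension of size 1 so that each mask value is broadcast across one full row:

```python
import torch

# Mask and matrix from the example above, as tensors
mask = torch.tensor([1, 0, 1])
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# mask is [n]; unsqueeze(-1) makes it [n, 1], so each mask value
# multiplies one whole row of the [n, m] matrix
out = mask.unsqueeze(-1) * matrix
print(out)
# tensor([[1, 2, 3],
#         [0, 0, 0],
#         [7, 8, 9]])

# Without the unsqueeze, the [n] mask is broadcast as [1, n] and
# masks columns instead of rows
cols_masked = mask * matrix
```

Note that a plain `mask * matrix` on tensors zeroes the middle column rather than the middle row, which is why the extra dimension matters.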

I am trying to figure out how to implement the `out = mask * matrix` operation when I have a batch of different masks and a single matrix, so the shapes are:

```
mask.shape   = [batch, n]
matrix.shape = [n, m]
out.shape    = [batch, n, m]
```

I think I might have to use `Tensor.repeat` to repeat the matrix over the batch dimension, but I'm not sure.
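On the repeat question: `Tensor.repeat` allocates a new tensor and copies the matrix once per batch element, whereas `Tensor.expand` returns a view with the same shape that shares the original storage. A minimal sketch comparing the two, with arbitrary sizes chosen for illustration:

```python
import torch

batch_size, n, m = 2, 3, 4
matrix = torch.randn(n, m)

# repeat(): new [batch, n, m] tensor, data copied batch_size times
repeated = matrix.unsqueeze(0).repeat(batch_size, 1, 1)

# expand(): [batch, n, m] view of the same storage; the batch
# dimension gets stride 0, so no data is copied
expanded = matrix.unsqueeze(0).expand(batch_size, -1, -1)

print(repeated.shape, expanded.shape)   # torch.Size([2, 3, 4]) for both
print(expanded.stride())                # (0, 4, 1)
print(torch.equal(repeated, expanded))  # True
```

Both hold the same values; the view is cheaper when the expanded tensor is only read, as it is in an elementwise multiply.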

I actually figured this out straight away. Unsqueeze the mask so it gains a trailing dimension of size 1, then repeat the matrix over the batch dimension:

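```python
import torch

batch_size = 2
a = torch.randn(batch_size, 3)  # masks, [batch, n] (random values; only the shapes matter here)
b = torch.randn(3, 4)           # matrix, [n, m]

a = a.unsqueeze(-1)             # [batch, n] -> [batch, n, 1]

b = b.unsqueeze(0)              # [n, m] -> [1, n, m]
b = b.repeat(batch_size, 1, 1)  # [1, n, m] -> [batch, n, m]

print(a.shape, b.shape)
# torch.Size([2, 3, 1]) torch.Size([2, 3, 4])

c = a * b                       # the size-1 last dim of a broadcasts over m
print(c.shape)
# torch.Size([2, 3, 4])
```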
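Because PyTorch broadcasting treats the `[n, m]` matrix as `[1, n, m]` automatically, the explicit repeat step should be optional. A minimal sketch with a hypothetical batch of two binary masks over the example matrix, plus an equivalent `torch.einsum` formulation (float dtypes are used here as a precaution for the einsum path):

```python
import torch

# Hypothetical batch of two binary masks over the 3 rows of the example matrix
masks = torch.tensor([[1, 0, 1],
                      [0, 1, 1]], dtype=torch.float32)
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]], dtype=torch.float32)

# [batch, n, 1] * [n, m] broadcasts to [batch, n, m]; no repeat needed
out = masks.unsqueeze(-1) * matrix
print(out.shape)  # torch.Size([2, 3, 3])

# Same result via einsum: b, n and m all appear in the output,
# so nothing is summed, only multiplied elementwise
out_einsum = torch.einsum("bn,nm->bnm", masks, matrix)
print(torch.equal(out, out_einsum))  # True
```

`out[0]` zeroes the middle row and `out[1]` zeroes the first row, and the intermediate `[batch, n, m]` copy of the matrix is never created.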