Hi,

I have a mask vector of binary values that I would like to use to mask rows in a matrix:

```
mask = [1, 0, 1]
matrix = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]
out = mask * matrix  # pseudocode: zero every row whose mask entry is 0
# desired result:
# out = [[1, 2, 3],
#        [0, 0, 0],
#        [7, 8, 9]]
```
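For reference, the single-mask case could probably be written like this in PyTorch. This is only a sketch, assuming `mask` and `matrix` are integer tensors; a plain `mask * matrix` would broadcast the mask across columns instead of rows, so the mask gets a trailing dimension first:

```python
import torch

mask = torch.tensor([1, 0, 1])             # shape [n]
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])         # shape [n, m]

# unsqueeze(1) turns the mask into shape [n, 1], so each entry
# multiplies a whole row of the matrix via broadcasting
out = mask.unsqueeze(1) * matrix
print(out)
```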

I am trying to figure out how to implement the `out = mask * matrix` operation when I have a batch of different masks and a single matrix, so the shapes are:

```
mask.shape = [batch, n]
matrix.shape = [n, m]
out.shape = [batch, n, m]
```

I think I might have to use `Tensor.repeat` to repeat the matrix over the batch dimension, but I'm not sure.
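One possibility that seems to avoid repeating the matrix is plain broadcasting. The snippet below is a minimal sketch under assumed sizes (`batch = 2`, `n = 3`, `m = 3`, chosen only for illustration): the mask gets a trailing dimension so that `[batch, n, 1]` multiplied by `[n, m]` broadcasts to `[batch, n, m]`.

```python
import torch

# Illustrative data: two different masks and one shared matrix
mask = torch.tensor([[1, 0, 1],
                     [0, 1, 1]])           # shape [batch, n]
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])         # shape [n, m]

# [batch, n, 1] * [n, m] -> [batch, n, m]; the matrix is not copied explicitly
out = mask.unsqueeze(-1) * matrix

print(out.shape)   # torch.Size([2, 3, 3])
print(out[0])      # rows 0 and 2 kept, row 1 zeroed
print(out[1])      # row 0 zeroed, rows 1 and 2 kept
```

If an explicit copy per batch element is really wanted, `matrix.unsqueeze(0).repeat(mask.shape[0], 1, 1)` followed by the same multiplication should give the same result, at the cost of extra memory.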