Hi,
I have a mask vector of binary values, and I would like to use it to mask rows of a matrix:
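```python
import torch

mask = torch.tensor([1, 0, 1])
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
# Add a trailing dim so the mask has shape [n, 1] and zeroes whole rows
out = mask.unsqueeze(1) * matrix
# out = [[1, 2, 3],
#        [0, 0, 0],
#        [7, 8, 9]]
```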
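As a side note on the single-matrix case: PyTorch broadcasting aligns trailing dimensions, so a 1-D mask of shape [n] multiplied with an [n, m] matrix is applied along the last dimension, i.e. it masks columns, not rows. A minimal sketch (assuming both are `torch` tensors) showing the difference:

```python
import torch

mask = torch.tensor([1, 0, 1])
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# Shape [3] broadcasts against [3, 3] along the LAST dim: this zeroes a column.
col_masked = mask * matrix
print(col_masked.tolist())  # [[1, 0, 3], [4, 0, 6], [7, 0, 9]]

# Shape [3, 1] is stretched across each row: this zeroes a whole row.
row_masked = mask.unsqueeze(1) * matrix
print(row_masked.tolist())  # [[1, 2, 3], [0, 0, 0], [7, 8, 9]]
```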
I am trying to figure out how to implement this masking operation when I have a batch of different masks and a single matrix, so the shapes are:
```
mask.shape = [batch, n]
matrix.shape = [n, m]
out.shape = [batch, n, m]
```
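One possible approach is to rely on broadcasting again: give the mask a trailing singleton dimension ([batch, n, 1]) and the matrix a leading one ([1, n, m]), so their product broadcasts to [batch, n, m] without copying the matrix. A minimal sketch, where the helper name `batch_row_mask` and the example values are hypothetical:

```python
import torch

def batch_row_mask(mask: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply each mask in a batch to the rows of one matrix.

    mask:   [batch, n] with 0/1 entries
    matrix: [n, m]
    returns [batch, n, m]
    """
    # [batch, n, 1] * [1, n, m] -> [batch, n, m] via broadcasting
    return mask.unsqueeze(-1) * matrix.unsqueeze(0)

mask = torch.tensor([[1, 0, 1],
                     [0, 1, 1]])
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

out = batch_row_mask(mask, matrix)
print(out.shape)        # torch.Size([2, 3, 3])
print(out[1].tolist())  # [[0, 0, 0], [4, 5, 6], [7, 8, 9]]
```

The explicit `matrix.unsqueeze(0)` is only for readability; broadcasting would prepend the batch dimension on its own.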
I think I might have to use `Tensor.repeat` to repeat the matrix over the batch dimension, but I'm not sure.
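For comparison, `Tensor.repeat` does work, but it materializes `batch` copies of the matrix; `Tensor.expand` yields the same [batch, n, m] shape as a view without copying, and plain broadcasting makes either step unnecessary. A small sketch, with arbitrary example sizes, checking that the three give the same result:

```python
import torch

batch, n, m = 4, 3, 5  # arbitrary example sizes
mask = torch.randint(0, 2, (batch, n))
matrix = torch.randn(n, m)

# 1) repeat: copies the matrix `batch` times -> [batch, n, m]
out_repeat = mask.unsqueeze(-1) * matrix.repeat(batch, 1, 1)

# 2) expand: same shape, but a view that shares memory with `matrix`
out_expand = mask.unsqueeze(-1) * matrix.expand(batch, n, m)

# 3) broadcasting: PyTorch prepends the batch dim implicitly
out_broadcast = mask.unsqueeze(-1) * matrix

assert torch.equal(out_repeat, out_expand)
assert torch.equal(out_repeat, out_broadcast)
print(out_broadcast.shape)  # torch.Size([4, 3, 5])
```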