Filter Output Using Mask

I have an output of shape 14 x 10 x 128, where 14 is the batch_size, 10 is the sequence_length, and 128 is the size of the vector representing the object associated with each sequence element.

Now, not all the sequence elements are relevant. For example, if we look at the first batch element (10 x 128), its sequence is made up of only 3 elements, i.e. only 3 x 128 are useful, whereas the remaining 7 x 128 are just padding. I have a mask of shape 14 x 10 that holds this information. How do I filter the output so that I get only the “relevant” output using the mask?
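If “filter” means dropping the padded positions altogether rather than zeroing them, boolean indexing with the mask does that. Below is a minimal sketch with random data of the shapes from the question; the per-sequence lengths are made up for illustration (only the first one, 3, comes from the question). The result is a flat tensor of relevant rows, so batch boundaries are lost unless you keep the lengths, e.g. mask.sum(dim=1).

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, dim = 14, 10, 128

outputs = torch.randn(batch_size, seq_len, dim)

# Hypothetical lengths: the first sequence has 3 relevant elements, as in the question;
# the others are random values between 1 and seq_len.
lengths = torch.randint(1, seq_len + 1, (batch_size,))
lengths[0] = 3

# 0/1 mask like the one in the question: 1 for the first lengths[b] positions of sequence b
mask = (torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)).long()  # shape (14, 10)

# Convert to bool before indexing: an integer tensor would be treated as indices, not a mask.
# Indexing a (14, 10, 128) tensor with a (14, 10) bool mask keeps only the masked rows.
relevant = outputs[mask.bool()]  # shape (lengths.sum(), 128)
print(relevant.shape)
```

The rows come out in batch-major order, so the first 3 rows of relevant belong to the first sequence.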

For example (shapes simplified to keep it succinct):

import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]], [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])  # first batch element: only two sequence elements are relevant; second: only one

# do I first expand the mask?
masks = masks[:, :, None].expand_as(outputs)

# then get the required outputs?
outputs = outputs * masks

You can either do it the way you proposed (even without expand_as(outputs), since broadcasting takes care of it) or use one of the approaches below. Pick your favourite :wink:
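A small sketch to check the broadcasting point, using the toy tensors from the question: multiplying by masks[:, :, None] (shape (2, 3, 1)) gives the same result as the explicitly expanded mask. Note that the padded positions come out as zeros; they are not removed.

```python
import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]],
                        [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])

# Explicit expansion, as in the question
expanded = outputs * masks[:, :, None].expand_as(outputs)
# Broadcasting: the trailing size-1 dimension is stretched to match automatically
broadcast = outputs * masks[:, :, None]

print(torch.equal(expanded, broadcast))  # True
print(broadcast[1])  # only the first row is kept; padded rows are all zeros
```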

import torch

# shape (2, 3, 2) here; (14, 10, 128) in the real case
outputs = torch.tensor([
    [[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]],
    [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]],
])
# shape (2, 3) here; (14, 10) in the real case
# first batch element: only two sequence elements are relevant; second: only one
masks = torch.tensor([
    [1, 1, 0],
    [1, 0, 0],
])

# Einsum (mask cast to the outputs' dtype so both operands match)
masked_outputs = torch.einsum("ijk,ij->ijk", outputs, masks.to(outputs.dtype))
# or unsqueeze + mul
masked_outputs = torch.mul(outputs, masks.unsqueeze(-1))  # same as outputs * masks.unsqueeze(-1)
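Both versions above zero out the padded positions but keep the (batch, seq_len, dim) shape. If you instead need one tensor per sequence holding only its relevant elements, one option (an alternative, not part of the approaches above) is boolean indexing followed by torch.split on the per-sequence counts. A sketch with the toy tensors:

```python
import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]],
                        [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])

# The two zeroing approaches agree (mask cast to float so einsum sees matching dtypes)
via_einsum = torch.einsum("ijk,ij->ijk", outputs, masks.to(outputs.dtype))
via_mul = outputs * masks.unsqueeze(-1)
print(torch.allclose(via_einsum, via_mul))  # True

# Dropping padded positions instead: keep the masked rows, then split them per sequence
packed = outputs[masks.bool()]               # shape (3, 2): all relevant rows, in batch order
lengths = masks.sum(dim=1).tolist()          # relevant elements per sequence: [2, 1]
per_sequence = torch.split(packed, lengths)  # tensors of shape (2, 2) and (1, 2)
for seq in per_sequence:
    print(seq)
```

Because the counts come from the mask itself, this does not assume the relevant elements sit at the start of each sequence.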