I have an output of shape 14 x 10 x 128, where 14 is the batch_size, 10 is the sequence_length, and 128 is the object vector representing the objects associated with each sequence element.
Now, not all the sequence elements are relevant. For example, if we look at the first batch element (10 x 128), its sequence is made up of only 3 elements, i.e. only 3 x 128 are useful, whereas the remaining 7 x 128 are just padded elements. I have a mask of shape 14 x 10 that holds this information. How do I filter the output so that I get only the “relevant” output using the mask?
For example (shapes simplified to keep it succinct):
import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]], [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
# first batch element: only two elements in seq are relevant; second: only one
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])
# do I first expand mask?
masks = masks[:, :, None].expand_as(outputs)
# then get the required outputs?
outputs = outputs * masks
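
To make the question concrete, here is a minimal sketch of the alternatives I can think of, assuming the mask is 1 for real elements and 0 for padding. The names `zeroed`, `relevant`, `lengths` and `per_sequence` are just illustrative, and I am not sure which of these is the recommended approach:

```python
import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]],
                        [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])

# Option 1: zero out padded positions; shape stays (batch, seq, dim).
# unsqueeze(-1) makes the mask (batch, seq, 1) so it broadcasts over dim.
zeroed = outputs * masks.unsqueeze(-1)

# Option 2: drop padded positions with a boolean mask.
# The result is flattened across the batch: (num_valid, dim).
relevant = outputs[masks.bool()]

# Option 3: split the flattened result back into one tensor per sequence.
lengths = masks.sum(dim=1).tolist()
per_sequence = torch.split(relevant, lengths)
```

Option 1 keeps the padded rows (as zeros), while options 2 and 3 actually remove them, at the cost of losing the regular batch shape.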