I have an output of shape 14 x 10 x 128, where 14 is the batch size, 10 is the (padded) sequence length, and 128 is the size of the object vector representing the objects associated with each sequence element.
Now, not all of the sequence elements are relevant. For example, the first batch element (10 x 128) holds a sequence of only 3 elements, i.e. only 3 x 128 of it is useful, while the remaining 7 x 128 are just padding. I have a mask of shape 14 x 10 that encodes this information. How do I use the mask to filter the output so that I get only the “relevant” elements?
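One common approach is to convert the mask to booleans and index the output with it. Indexing a 14 x 10 x 128 tensor with a 14 x 10 boolean mask keeps every 128-dim vector whose mask entry is True and drops the rest, so the result is flattened across the batch to (number of relevant elements) x 128. The sketch below assumes `1` marks a relevant element and `0` marks padding; the random data and the per-sequence lengths are made up for illustration.

```python
import torch

# Hypothetical data with the question's shapes: batch=14, seq_len=10, dim=128.
outputs = torch.randn(14, 10, 128)

# Hypothetical mask: 1 = relevant element, 0 = padding.
mask = torch.zeros(14, 10, dtype=torch.long)
mask[0, :3] = 1   # first sequence: 3 relevant elements, 7 padded
mask[1:, :2] = 1  # the other 13 sequences: 2 relevant elements each (arbitrary)

# Boolean indexing over the first two dimensions selects whole 128-dim vectors.
relevant = outputs[mask.bool()]

# The batch dimension is flattened away: 3 + 13 * 2 = 29 rows in total.
print(relevant.shape)  # torch.Size([29, 128])
```

The rows come out in batch-major order, so the first three rows of `relevant` are the useful part of the first batch element.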
For example (shapes simplified to keep it succinct):
```python
import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]],
                        [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])
# first batch element: only two elements in the sequence are relevant; only one in the second

# do I first expand the mask?
masks = masks[:, :, None].expand_as(outputs)
# then get the required outputs?
outputs = outputs * masks
```
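On the simplified example, multiplying by the mask only zeroes the padded vectors: the result keeps the padded shape (2, 3, 2). The sketch below, assuming `1` marks a relevant element, contrasts that with boolean indexing, which actually removes the padded rows. If the batch structure is still needed, `torch.split` with the per-sequence counts from `masks.sum(dim=1)` regroups the flat result into one tensor per sequence; this works even if padding is not at the end, because both the indexing and the counts follow row order.

```python
import torch

outputs = torch.tensor([[[0.2, 0.3], [0.1, 0.4], [0.5, 0.6]],
                        [[0.7, 0.8], [0.9, 0.11], [0.14, 0.15]]])
masks = torch.tensor([[1, 1, 0], [1, 0, 0]])

# Multiplying (with broadcasting instead of expand_as) zeroes padded vectors
# but keeps the padded shape.
zeroed = outputs * masks[:, :, None]
print(zeroed.shape)  # torch.Size([2, 3, 2])

# Boolean indexing drops the padded vectors: shape (3, 2).
selected = outputs[masks.bool()]
# selected holds the rows [0.2, 0.3], [0.1, 0.4], [0.7, 0.8]

# Regroup per sequence using the number of relevant elements in each one.
lengths = masks.sum(dim=1)                          # tensor([2, 1])
per_seq = torch.split(selected, lengths.tolist())   # tuple of tensors
print([tuple(t.shape) for t in per_seq])            # [(2, 2), (1, 2)]
```

Because the sequences have different lengths, the per-sequence result is a tuple of tensors rather than a single tensor.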