Variable length sequence avg pool / How to slice based off a long tensor of slice_ends?

Hello! I’m trying to perform an average pool over a sequence ignoring the padding elements. So far I have:

from typing import Optional
import torch

def avg_pool(data, input_lens: Optional[torch.LongTensor] = None):
    """A 1d avg pool for sequence data.

    Args:
        data: of dim (batch, seq_len, hidden_size)
        input_lens: Optional long tensor of dim (batch,) that represents the
            original lengths without padding. Tokens past these lengths will
            not be included in the average.

    Returns:
        Tensor (batch, hidden_size)
    """
    if input_lens is not None:
        return torch.stack([
            torch.sum(data[i, :l, :], dim=0) / l for i, l in enumerate(input_lens)
        ])
    return torch.sum(data, dim=1) / float(data.shape[1])

This works fine, but the list comprehension feels a little hackish. I’m still trying to learn all the fancy ways of indexing, so I was wondering if anyone knew of a nice vectorized way to do this, or does this have to be a loop?

To elaborate more, fundamentally I am looking for some operation that lets you slice based on a LongTensor of start or end points. So something like an `assign_slices` method on tensors which takes in (starts, ends, new_val, dim):

>>> a = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> a.assign_slices(torch.tensor([2, 1]), None, 0, dim=0)
tensor([[1, 2, 0], [4, 0, 0]])
>>> a.assign_slices(None, torch.tensor([2, 1]), 99, dim=0)
tensor([[99, 99, 3], [99, 5, 6]])
>>> a.assign_slices(torch.tensor([1, 0]), torch.tensor([2, 2]), 0, dim=0)
tensor([[1, 0, 3], [0, 0, 6]])

Does something like this exist?

If it does then I can just make everything after the input_lens be 0 and then sum as normal to get my average.
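(For reference, the behavior of the hypothetical `assign_slices` above can be emulated with a broadcasted `arange` mask. This is just a minimal sketch, assuming a 2D input and slices running along the last dim; the name and signature are made up to match the examples:)

```python
import torch

def assign_slices(t, starts, ends, new_val):
    # Hypothetical helper matching the examples above: for each row i,
    # set t[i, starts[i]:ends[i]] = new_val, without a Python loop.
    # Assumes t is 2D and slices run along the last dimension.
    idx = torch.arange(t.size(1))                      # (seq_len,)
    mask = torch.ones_like(t, dtype=torch.bool)
    if starts is not None:
        mask &= idx.unsqueeze(0) >= starts.unsqueeze(1)
    if ends is not None:
        mask &= idx.unsqueeze(0) < ends.unsqueeze(1)
    return t.masked_fill(mask, new_val)

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
assign_slices(a, torch.tensor([2, 1]), None, 0)
# tensor([[1, 2, 0], [4, 0, 0]])
```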

I found myself suddenly wanting this same functionality again when trying to make a model predict spans of tokens where the end of the span is conditioned to come after the start (so basically the starts were a long tensor, and I wanted to set everything before each start to a large negative value).
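(That span use case also reduces to the same `arange` mask trick. A sketch, where the shapes and variable names are illustrative assumptions rather than the original model’s:)

```python
import torch

# Mask end-token logits so every position strictly before the predicted
# start gets a large negative value (and thus ~zero probability).
batch, seq_len = 2, 5
end_logits = torch.randn(batch, seq_len)
starts = torch.tensor([2, 0])  # predicted start index per example

positions = torch.arange(seq_len).unsqueeze(0)    # (1, seq_len)
before_start = positions < starts.unsqueeze(1)    # (batch, seq_len) bool
end_logits = end_logits.masked_fill(before_start, -1e9)
```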

This code might look a bit hacky, but should be a bit faster than your current loop:

N, C, H = 64, 100, 32
data = torch.randn(N, C, H)
input_lens = torch.randint(2, H, (N,))

def fun1(data, input_lens):
    ret = torch.stack([
        torch.sum(data[i, :l, :], dim=0) / l for i, l in enumerate(input_lens)
    ])
    return ret

def fun2(data, input_lens):
    idx = torch.arange(C).unsqueeze(0).expand(N, -1)
    idx = idx < input_lens.unsqueeze(1)
    idx = idx.unsqueeze(2).expand(-1, -1, H)
    ret = (data * idx.float()).sum(1) / input_lens.unsqueeze(1).float()
    return ret

%timeit fun1(data, input_lens)
> 906 µs ± 5.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fun2(data, input_lens)
> 388 µs ± 5.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

I’m not sure what your shapes are, so you might try profiling it for your use case. :wink:
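(A quick sanity check, reusing the setup above, that the loop version and the masked version actually agree up to float accumulation order:)

```python
import torch

N, C, H = 64, 100, 32
data = torch.randn(N, C, H)
input_lens = torch.randint(2, H, (N,))

def fun1(data, input_lens):
    # loop version: average only the first l timesteps of each sequence
    return torch.stack([
        torch.sum(data[i, :l, :], dim=0) / l for i, l in enumerate(input_lens)
    ])

def fun2(data, input_lens):
    # masked version: zero out padding positions, then sum and divide
    idx = torch.arange(C).unsqueeze(0).expand(N, -1)
    idx = idx < input_lens.unsqueeze(1)
    idx = idx.unsqueeze(2).expand(-1, -1, H)
    return (data * idx.float()).sum(1) / input_lens.unsqueeze(1).float()

print(torch.allclose(fun1(data, input_lens), fun2(data, input_lens), atol=1e-5))
```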


@ptrblck Thanks for taking a look at this! Using arange is clever.

It’s somewhat a shame that there’s not a cleaner way of doing this kind of indexing, but I guess it works ¯\_(ツ)_/¯.

Thank you!