A common technique for certain NLP tasks is to mean-pool sentences or entities that span several tokens. In those cases, the mean of a slice of the final context embeddings is calculated. This is easy to do for a single row, but less obvious to compute in a batch.
Say that I want to take the mean pool of a couple of entities, each spanning a few tokens. I have
last_hidden_state, which is a 512 (tokens) x 768 (embedding size) tensor.
I also have
mean_slices, which is an n x 2 tensor, where n is the number of entities within the 512 tokens. Each row contains two numbers: the start (inclusive) and end (exclusive) token index of an entity.
Since there are no ragged tensors in PyTorch, I came up with a workaround: compute the cumulative sum of the embeddings along the 512-token axis, then, for each entity, subtract the cumulative sum just before the start index from the cumulative sum at the end index, and divide by the length of the slice.
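As a small sanity check of the idea, here is a toy example (made-up numbers, 6 tokens with embedding size 2 standing in for the real hidden states) showing that the cumsum-and-subtract trick recovers the same per-span sum as slicing directly:

```python
import torch

# Toy stand-in for last_hidden_state: 6 tokens, embedding size 2
x = torch.arange(12, dtype=torch.float32).reshape(6, 2)

# Prepend a zero row to the cumulative sum so that
# padded[i] holds the sum of tokens 0..i-1
padded = torch.nn.functional.pad(x.cumsum(dim=0), (0, 0, 1, 0))

# Sum over tokens start..end-1 via a single subtraction
start, end = 2, 5
trick = padded[end] - padded[start]
direct = x[start:end].sum(dim=0)
print(torch.allclose(trick, direct))  # True
```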
Here’s how it looks in PyTorch code:
# Prepend a zero row so padded[i] is the sum of tokens 0..i-1
padded = torch.nn.functional.pad(last_hidden_state.cumsum(dim=0), (0, 0, 1, 0))
# Difference of cumsums gives per-entity sums; divide by slice lengths
# (squeeze(1) rather than squeeze() so the entity dimension survives when n == 1)
pools = torch.diff(padded[mean_slices], dim=1).squeeze(1) / torch.diff(mean_slices, dim=1)
However, I am wondering whether this works as intended: since every token contributes to the cumulative sum, gradients will be passed to all tokens, not just the ones being averaged… I think.
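One way to probe that concern empirically (with made-up data in place of the real hidden states, and squeeze(1) in place of squeeze() so the shape is stable) is to backpropagate through the pooled means and inspect which token rows end up with a nonzero gradient. In this toy run, the cumsum contributions from tokens outside the spans cancel in the subtraction, so only the pooled rows come out nonzero:

```python
import torch

# Made-up stand-in for last_hidden_state: 10 tokens, embedding size 4
last_hidden_state = torch.randn(10, 4, requires_grad=True)

# Two entities: tokens [2, 5) and [7, 9)  (end index exclusive)
mean_slices = torch.tensor([[2, 5], [7, 9]])

padded = torch.nn.functional.pad(last_hidden_state.cumsum(dim=0), (0, 0, 1, 0))
pools = torch.diff(padded[mean_slices], dim=1).squeeze(1) / torch.diff(mean_slices, dim=1)

# Backpropagate a scalar through the pooled means
pools.sum().backward()

# Which token rows received any gradient?
nonzero_rows = last_hidden_state.grad.abs().sum(dim=1).nonzero().flatten()
print(nonzero_rows)  # tensor([2, 3, 4, 7, 8])
```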