Best technique for mean pooling embeddings of middle tokens in transformers

A common technique for certain NLP tasks is to mean-pool sentences or entities that span several tokens, i.e. to compute the mean of a slice of the final contextual embeddings. This is easy to do for a single row, but less obvious to do in a batch.

Say I want to mean-pool a couple of entities, each spanning a few tokens.

I have last_hidden_state, a tensor of shape 512 (tokens) x 768 (embedding size).

I also have mean_slices, an n x 2 tensor, where n is the number of entities within the 512 tokens. Each row holds the start and end token index of one entity (start inclusive, end exclusive, as in a Python slice).
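For reference, the straightforward non-batched version is a Python loop over the rows of mean_slices. This is a minimal sketch assuming [start, end) spans with exclusive ends; the function name mean_pool_loop and the example spans are hypothetical, used only for illustration.

```python
import torch


def mean_pool_loop(last_hidden_state: torch.Tensor, mean_slices: torch.Tensor) -> torch.Tensor:
    """Reference mean pooling, one slice per entity (hypothetical helper).

    last_hidden_state: (num_tokens, hidden) tensor
    mean_slices: (n, 2) integer tensor of [start, end) token indices
    returns: (n, hidden) tensor of span means
    """
    pools = [
        last_hidden_state[start:end].mean(dim=0)  # mean over the tokens of one span
        for start, end in mean_slices.tolist()
    ]
    return torch.stack(pools, dim=0)


last_hidden_state = torch.randn(512, 768)
mean_slices = torch.tensor([[0, 3], [5, 9], [10, 11]])  # example spans (assumed)
pools = mean_pool_loop(last_hidden_state, mean_slices)
print(pools.shape)  # torch.Size([3, 768])
```

It is easy to read and obviously correct, but it launches one small op per entity, which is what the batched trick below tries to avoid.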

Since PyTorch has no ragged tensors, I came up with a workaround: compute the cumulative sum of the embeddings along the 512-token axis, then, for each entity, subtract the cumsum just before the start index from the cumsum at the end index, and divide by the length of the slice.
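Written out, this is the usual prefix-sum identity. The notation is introduced here for illustration: $h_t$ is the hidden vector of token $t$, $S_i$ is the zero-padded cumulative sum, and $[s, e)$ is an entity span with exclusive end. The last line gives the gradient with respect to each token, since $h_t$ appears in $S_e$ exactly when $t < e$ and in $S_s$ exactly when $t < s$.

```latex
S_0 = \mathbf{0}, \qquad S_i = \sum_{t=0}^{i-1} h_t \quad (1 \le i \le 512)

\bar{h}_{[s,e)} = \frac{1}{e - s} \sum_{t=s}^{e-1} h_t = \frac{S_e - S_s}{e - s}

\frac{\partial \bar{h}_{[s,e)}}{\partial h_t} = \frac{\mathbb{1}[t < e] - \mathbb{1}[t < s]}{e - s}\, I = \frac{\mathbb{1}[s \le t < e]}{e - s}\, I
```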

Here's how it looks in PyTorch:

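import torch

# Prepend a zero row along the token axis so that padded[i] = sum of tokens [0, i)
padded = torch.nn.functional.pad(
    last_hidden_state.cumsum(dim=0), (0, 0, 1, 0)
)

# padded[end] - padded[start] = sum over tokens [start, end); divide by span length.
# squeeze(1) keeps the entity axis even when n == 1
pools = torch.diff(
    padded[mean_slices], dim=1
).squeeze(1) / torch.diff(mean_slices, dim=1)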
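The snippet handles one sequence at a time. The same cumsum trick extends to a real batch if every sequence has the same number of spans n, which is an assumption here (ragged counts could be padded with dummy spans such as [0, 1] and discarded afterwards). A minimal sketch; the function name and the (B, n, 2) layout of spans are hypothetical:

```python
import torch
import torch.nn.functional as F


def mean_pool_spans_batched(hidden: torch.Tensor, spans: torch.Tensor) -> torch.Tensor:
    """Cumsum-based span mean pooling over a batch (hypothetical helper).

    hidden: (B, T, H) float tensor
    spans:  (B, n, 2) int64 tensor of [start, end) indices, same n per sequence
    returns: (B, n, H) tensor of span means
    """
    batch_size = hidden.shape[0]
    # Zero row prepended on the token axis: padded[b, i] = sum of tokens [0, i) of sequence b
    padded = F.pad(hidden.cumsum(dim=1), (0, 0, 1, 0))  # (B, T + 1, H)
    # Advanced indexing pairs each batch index with that sequence's span indices
    batch_idx = torch.arange(batch_size, device=hidden.device)[:, None, None]  # (B, 1, 1)
    gathered = padded[batch_idx, spans]  # (B, n, 2, H)
    sums = gathered[:, :, 1] - gathered[:, :, 0]  # prefix sum at end minus at start
    lengths = (spans[..., 1] - spans[..., 0]).unsqueeze(-1)  # (B, n, 1)
    return sums / lengths


hidden = torch.randn(2, 512, 768)
spans = torch.tensor([[[0, 3], [3, 4]], [[10, 20], [100, 101]]])
pools = mean_pool_spans_batched(hidden, spans)
print(pools.shape)  # torch.Size([2, 2, 768])
```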

However, I'm worried this may not work as intended: I suspect gradients will be passed to all tokens, not just the ones being averaged.
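One alternative worth comparing against is a different technique from the cumsum trick, not something proposed in this thread: build an explicit n x num_tokens averaging matrix and multiply. Tokens outside every span get weight zero, so by construction they receive no gradient, and it also avoids subtracting large prefix sums, which in float32 can cost some precision on long sequences. A minimal sketch assuming non-empty [start, end) spans (an empty span would divide by zero); the function name is hypothetical:

```python
import torch


def mean_pool_via_matmul(last_hidden_state: torch.Tensor, mean_slices: torch.Tensor) -> torch.Tensor:
    """Span mean pooling with an explicit averaging matrix (hypothetical helper).

    last_hidden_state: (num_tokens, hidden) tensor
    mean_slices: (n, 2) integer tensor of [start, end) indices
    """
    num_tokens = last_hidden_state.shape[0]
    positions = torch.arange(num_tokens, device=last_hidden_state.device)  # (T,)
    starts = mean_slices[:, :1]  # (n, 1)
    ends = mean_slices[:, 1:]  # (n, 1)
    # mask[r, t] is True iff token t lies inside span r
    mask = (positions >= starts) & (positions < ends)  # (n, T)
    weights = mask.to(last_hidden_state.dtype)
    weights = weights / weights.sum(dim=1, keepdim=True)  # each row sums to 1
    return weights @ last_hidden_state  # (n, hidden)


last_hidden_state = torch.randn(5, 2, requires_grad=True)
mean_slices = torch.tensor([[0, 3], [3, 4]])
pools = mean_pool_via_matmul(last_hidden_state, mean_slices)
pools.sum().backward()
print(last_hidden_state.grad)
```

The trade-off is an extra n x num_tokens weight tensor, which is usually small next to the hidden states.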


I tested it in PyTorch with code like this:

import torch

last_hidden_state = torch.randn(5, 2, requires_grad=True)
mean_slices = torch.tensor([[0, 3], [3, 4]])

padded = torch.nn.functional.pad(
    last_hidden_state.cumsum(dim=0), (0, 0, 1, 0)
)

pools = torch.diff(
    padded[mean_slices], dim=1
).squeeze(1) / torch.diff(mean_slices, dim=1)

pools.sum().backward()
print(last_hidden_state.grad)
# output (values rounded)
"""
tensor([[0.33, 0.33],
        [0.33, 0.33],
        [0.33, 0.33],
        [1.00, 1.00],
        [0.00, 0.00]])
"""
# the same computation without cumsum, using plain slicing
last_hidden_state.grad.zero_()  # reset the gradient from the first backward pass
span1 = last_hidden_state[0:3, ...].mean(0, keepdim=True)
span2 = last_hidden_state[3:4, ...].mean(0, keepdim=True)
spans = torch.cat([span1, span2], dim=0)
spans.sum().backward()
print(last_hidden_state.grad)
# output (values rounded)
"""
tensor([[0.33, 0.33],
        [0.33, 0.33],
        [0.33, 0.33],
        [1.00, 1.00],
        [0.00, 0.00]])
"""

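So the two gradients are the same: every token inside a span receives a gradient of 1/span_len from that span, and tokens outside every span receive no gradient. The cumsum does route gradient through the tokens before a span's start, but there the +1/span_len from the cumsum at the end index and the -1/span_len from the cumsum at the start index cancel.
Your solution works fine; feel free to use it.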
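To go beyond the single hand-picked example above, a randomized comparison can check values and gradients together. This is a sketch assuming non-empty [start, end) spans; the helper names are hypothetical, and float64 is used so that torch.allclose's default tolerances apply (float32 cumsums over 512 tokens may need a looser atol).

```python
import torch


def cumsum_mean_pool(last_hidden_state, mean_slices):
    # The cumsum approach from the question, with squeeze(1) so n == 1 also works
    padded = torch.nn.functional.pad(last_hidden_state.cumsum(dim=0), (0, 0, 1, 0))
    return torch.diff(padded[mean_slices], dim=1).squeeze(1) / torch.diff(mean_slices, dim=1)


def loop_mean_pool(last_hidden_state, mean_slices):
    # Plain slicing reference, one span at a time
    return torch.stack([last_hidden_state[s:e].mean(dim=0) for s, e in mean_slices.tolist()])


torch.manual_seed(0)
num_tokens, hidden, n = 512, 16, 8
# Random non-empty spans [start, end) inside the sequence
starts = torch.randint(0, num_tokens - 1, (n,))
lengths = torch.randint(1, 20, (n,))
ends = torch.minimum(starts + lengths, torch.tensor(num_tokens))
mean_slices = torch.stack([starts, ends], dim=1)

# Two independent leaf copies of the same data, so the gradients can be compared
x1 = torch.randn(num_tokens, hidden, dtype=torch.float64, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

out1 = cumsum_mean_pool(x1, mean_slices)
out2 = loop_mean_pool(x2, mean_slices)
out1.sum().backward()
out2.sum().backward()

values_match = torch.allclose(out1, out2)
grads_match = torch.allclose(x1.grad, x2.grad)
print(values_match, grads_match)
```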
