Variable-length sequence avg pool / How to slice based on a LongTensor of slice_ends?

Hello! I’m trying to perform an average pool over a sequence, ignoring the padding elements. So far I have:

import torch
from typing import Optional

def avg_pool(data: torch.Tensor, input_lens: Optional[torch.LongTensor] = None):
    """
    A 1d avg pool for sequence data
    Args:
        data: of dim (batch, seq_len, hidden_size)
        input_lens: Optional long tensor of dim (batch,) that represents the
            original lengths without padding. Tokens past these lengths will not
            be included in the average.

    Returns:
        Tensor (batch, hidden_size)

    """
    if input_lens is not None:
        # One slice per sequence: sum its first l steps and divide by l
        return torch.stack([
            torch.sum(data[i, :l, :], dim=0) / l for i, l in enumerate(input_lens)
        ])
    else:
        return torch.sum(data, dim=1) / float(data.shape[1])

This works fine, but the list comprehension feels a little hackish. I’m still learning all the fancy ways of indexing, so I was wondering whether anyone knows a nice vectorized way to do this, or does it have to be a loop?

To elaborate, what I’m fundamentally looking for is an operation that slices based on a LongTensor of start or end points: something like an “assign_slices” method on tensors that takes (starts, ends, new_val, dim):

>>> a = torch.tensor([[1,2,3],[4,5,6]])
>>> a.assign_slices(torch.tensor([2,1]), None, 0, dim=0)
tensor([[1, 2, 0], [4, 0, 0]])
>>> a.assign_slices(None, torch.tensor([2,1]), 99, dim=0)
tensor([[99, 99, 3], [99, 5, 6]])
>>> a.assign_slices(torch.tensor([1, 0]), torch.tensor([2, 2]), 0, dim=0)
tensor([[1, 0, 3], [0, 0, 6]])
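To pin down the intended semantics, here is a minimal sketch of the hypothetical `assign_slices` as a free function. The name is not a PyTorch API; it is illustration only. It assumes a 2-D tensor with one start/end per row and slices along the last dimension, so the `dim` argument from the examples is dropped. It builds a boolean mask by comparing `torch.arange` positions against the per-row bounds and then uses `masked_fill`.

```python
from typing import Optional, Union

import torch


def assign_slices(
    a: torch.Tensor,
    starts: Optional[torch.Tensor],
    ends: Optional[torch.Tensor],
    value: Union[int, float],
) -> torch.Tensor:
    """Hypothetical helper: for each row i, set a[i, starts[i]:ends[i]] = value.

    a: 2-D tensor of shape (rows, length).
    starts / ends: LongTensors of shape (rows,); None means 0 / length,
    mirroring Python slice defaults. Returns a new tensor; `a` is unchanged.
    """
    length = a.shape[-1]
    # Column positions 0..length-1, shaped (1, length) to broadcast against rows.
    pos = torch.arange(length, device=a.device).unsqueeze(0)
    # Start with "every position is inside the slice", then narrow it down.
    mask = torch.ones_like(a, dtype=torch.bool)
    if starts is not None:
        mask &= pos >= starts.unsqueeze(1)  # at or after each row's start
    if ends is not None:
        mask &= pos < ends.unsqueeze(1)  # strictly before each row's end
    return a.masked_fill(mask, value)
```

For example, `assign_slices(torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([2, 1]), None, 0)` should give `tensor([[1, 2, 0], [4, 0, 0]])`, matching the first example above.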

Does something like this exist?

If it does, I can just set everything past input_lens to 0, sum as normal, and divide by input_lens to get my average.
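The "zero the padding, then sum" plan can be written directly with a mask, without any per-row slicing. This is a minimal sketch; `masked_mean` is a hypothetical name, and it assumes `data` is (batch, seq_len, hidden) and every length is in [1, seq_len].

```python
import torch


def masked_mean(data: torch.Tensor, input_lens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average over dim 1 of (batch, seq_len, hidden), ignoring padding."""
    seq_len = data.shape[1]
    # padding[b, t] is True for time steps at or past the true length of sequence b.
    padding = torch.arange(seq_len, device=data.device).unsqueeze(0) >= input_lens.unsqueeze(1)
    # Zero the padded steps; the (batch, seq_len, 1) mask broadcasts over hidden.
    zeroed = data.masked_fill(padding.unsqueeze(-1), 0.0)
    # Sum as normal, then divide by the true lengths rather than seq_len.
    return zeroed.sum(dim=1) / input_lens.unsqueeze(1).to(data.dtype)
```

Using `masked_fill` instead of multiplying by the mask means that an `inf` or `NaN` sitting in a padded position is replaced outright rather than turned into `NaN` by `0 * inf`.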

I found myself wanting this same functionality again when trying to make a model predict spans of tokens, where the end of a span was constrained to come after its start. There the starts were a LongTensor, and I wanted to set everything before each start to a large negative value.
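The span case uses the same arange comparison, just against the starts instead of the lengths. This is a sketch under assumptions: `mask_before_start` is a hypothetical name, the end scores are assumed to be (batch, seq_len), and `-1e9` stands in for the "large negative value".

```python
import torch


def mask_before_start(end_logits: torch.Tensor, starts: torch.Tensor, neg: float = -1e9) -> torch.Tensor:
    """Hypothetical helper: push end-position scores that lie before each span's start to `neg`.

    end_logits: (batch, seq_len) scores for the span end position.
    starts: (batch,) LongTensor of start indices.
    """
    seq_len = end_logits.shape[1]
    pos = torch.arange(seq_len, device=end_logits.device).unsqueeze(0)  # (1, seq_len)
    before_start = pos < starts.unsqueeze(1)  # (batch, seq_len), True where end < start
    return end_logits.masked_fill(before_start, neg)
```

After a softmax over the last dimension, positions before the start get (numerically) zero probability, while the start position itself stays allowed, so single-token spans remain possible.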

This code might look a bit hacky, but should be a bit faster than your current loop:

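import torch

N, C, H = 64, 100, 32  # batch size, sequence length, hidden size
data = torch.randn(N, C, H)
input_lens = torch.randint(2, H, (N,))  # lengths in [2, H); any length in [1, C] is valid

def fun1(data, input_lens):
    # Original loop: sum the first l steps of each sequence, divide by l
    ret = torch.stack([
        torch.sum(data[i, :l, :], dim=0) / l for i, l in enumerate(input_lens)
    ])
    return ret

def fun2(data, input_lens):
    N, C, H = data.shape
    # idx[n, c] is True where time step c lies inside sequence n's length
    idx = torch.arange(C, device=data.device).unsqueeze(0).expand(N, -1)
    idx = idx < input_lens.unsqueeze(1)
    # Repeat the mask across the hidden dimension
    idx = idx.unsqueeze(2).expand(-1, -1, H)
    # Zero the padded steps, sum over time, divide by the true lengths
    ret = (data * idx.float()).sum(1) / input_lens.unsqueeze(1).float()
    return ret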

%timeit fun1(data, input_lens)
> 906 µs ± 5.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fun2(data, input_lens)
> 388 µs ± 5.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
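`%timeit` is an IPython magic. Outside IPython, `torch.utils.benchmark.Timer` gives comparable measurements. The sketch below re-declares compact versions of the two approaches so it runs on its own. The function names are illustrative, and the lengths are drawn from [1, C], which is an assumption about realistic inputs.

```python
import torch
import torch.utils.benchmark as benchmark

N, C, H = 64, 100, 32  # batch size, sequence length, hidden size
data = torch.randn(N, C, H)
input_lens = torch.randint(1, C + 1, (N,))  # every length in [1, C]


def loop_avg(data, input_lens):
    # Per-sequence Python loop, as in the question.
    return torch.stack([data[i, :l].sum(0) / l for i, l in enumerate(input_lens)])


def masked_avg(data, input_lens):
    # Vectorized arange mask, as in the answer; the bool mask promotes to float.
    mask = torch.arange(data.shape[1]).unsqueeze(0) < input_lens.unsqueeze(1)
    return (data * mask.unsqueeze(2)).sum(1) / input_lens.unsqueeze(1)


results = {}
for name, fn in [("loop", loop_avg), ("masked", masked_avg)]:
    timer = benchmark.Timer(
        stmt="fn(data, input_lens)",
        globals={"fn": fn, "data": data, "input_lens": input_lens},
    )
    results[name] = timer.timeit(100)  # a Measurement; .mean is seconds per run
    print(name, results[name])
```

By default `Timer` runs with a single thread (see its `num_threads` argument), so absolute numbers may differ from `%timeit`; the relative comparison is what matters.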

I’m not sure what your shapes are, so you might want to profile both versions for your use case. 😉
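One more formulation that may be worth including in that profiling (a variation, not from the thread): turn the mask into per-step weights of `1 / length` and let `torch.bmm` compute the weighted sum, which avoids materializing the full (N, C, H) masked copy. `avg_pool_bmm` is a hypothetical name, and it assumes every length is in [1, seq_len].

```python
import torch


def avg_pool_bmm(data: torch.Tensor, input_lens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: masked average over dim 1 of (batch, seq_len, hidden) via bmm."""
    seq_len = data.shape[1]
    # valid[b, t] is True for real (non-padding) time steps.
    valid = torch.arange(seq_len, device=data.device).unsqueeze(0) < input_lens.unsqueeze(1)
    # Each valid step gets weight 1 / length; padding gets weight 0.
    weights = valid.to(data.dtype) / input_lens.unsqueeze(1).to(data.dtype)
    # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
    return torch.bmm(weights.unsqueeze(1), data).squeeze(1)
```

Padding is still multiplied by zero inside the matmul, so if padded positions can hold `inf` or `NaN`, zero them first (for example with `masked_fill`).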

1 Like

@ptrblck Thanks for taking a look at this! Using arange is clever.

It’s a bit of a shame that there isn’t a cleaner way to do this kind of indexing, but I guess it works ¯\_(ツ)_/¯.
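For comparison, a prefix-sum formulation (not from the thread) avoids building a mask at all. The sum of the first `l` steps is simply the running sum at index `l - 1`, which `gather` can pick out per sequence. This is a sketch: `avg_pool_cumsum` is a hypothetical name, and it assumes every length is at least 1 (a zero length would index `-1`).

```python
import torch


def avg_pool_cumsum(data: torch.Tensor, input_lens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: masked average via prefix sums; the first l steps sum to cumsum[l - 1]."""
    csum = data.cumsum(dim=1)  # (batch, seq_len, hidden) running sums over time
    # Index of the last valid step per sequence, widened to select whole hidden vectors.
    last = (input_lens - 1).view(-1, 1, 1).expand(-1, 1, data.shape[2])
    totals = csum.gather(1, last).squeeze(1)  # (batch, hidden)
    return totals / input_lens.unsqueeze(1).to(data.dtype)
```

Steps after the last valid one never enter the result, even if they hold `inf`. The trade-off is that long prefix sums accumulate somewhat more floating-point rounding than a direct sum.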

Thank you!