Efficient average of subparts of a vector

I’m looking for an operator that averages groups of rows of a matrix, given the matrix and a list of offsets:

input = Variable(torch.FloatTensor([[1,1,1],
                                    [2,2,2],
                                    [3,3,3],
                                    [4,4,4],
                                    [5,5,5],
                                    [6,6,6]]))
offsets = Variable(torch.LongTensor([0,4]))
out = my_operator(input, offsets, mode='mean')
#Variable containing:
#  2.5  2.5   2.5   <- the first four vectors are averaged together
#  5.5  5.5   5.5   <- the last two are averaged together
#[torch.FloatTensor of size 2x3]
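
For reference, a naive version of this operator is easy to write with plain slicing (my own sketch, assuming the usual torch / Variable imports, not a built-in), but it loops in Python over the segments, which is what I’d like to avoid:

def my_operator(input, offsets, mode='mean'):
    # offsets mark the start of each segment; the last one runs to the end
    starts = offsets.data.tolist()
    ends = starts[1:] + [input.size(0)]
    reduce = (lambda c: c.mean(0)) if mode == 'mean' else (lambda c: c.sum(0))
    # reduce each [start, end) slice along dim 0 and stack the results
    return torch.stack([reduce(input[s:e]) for s, e in zip(starts, ends)], 0)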

Consider the nn.EmbeddingBag module, which takes a list of indices and offsets and returns averaged embeddings. One could rewrite this module using the operator defined above:

def EmbeddingBag(emb, input, offsets, mode='mean'):
    # look up the embeddings, then drop the batch dim so my_operator
    # receives a plain (num_indices, dim) matrix
    embedded_input = emb(input.unsqueeze(0)).squeeze(0)
    return my_operator(embedded_input, offsets, mode=mode)

input = Variable(torch.LongTensor([1,2,4,5,4,3,2,9]))
offsets = Variable(torch.LongTensor([0,4]))
emb = nn.Embedding(10, 3)
EmbeddingBag(emb, input, offsets)

Is there a way to get this operator? Is it available in PyTorch? In Torch? I feel it could be very useful in many situations.


I’m trying to hack the nn.EmbeddingBag module to get the desired effect:

import torch
import torch.nn as nn
from torch.autograd import Variable


class TensorReducer(nn.Module):

    def __init__(self):
        super(TensorReducer, self).__init__()
        # placeholder EmbeddingBag; its weight matrix is swapped at every forward
        self.emb = nn.EmbeddingBag(0, 0)

    def set_embs(self, embeddings):
        # use the incoming vectors as the "embedding" table
        self.emb.weight = nn.Parameter(embeddings.data)

    def get_dummy_input(self, size):
        # indices 0..size-1, so every row is looked up exactly once
        return Variable(torch.LongTensor(list(range(size))))

    def forward(self, inputs, offsets):
        self.set_embs(inputs)
        dummy = self.get_dummy_input(inputs.size(0))
        return self.emb(dummy, offsets)
    

red = TensorReducer()

inputs = Variable(torch.FloatTensor([[1,1,1],[2,2,2],[3,3,3],
                                     [4,4,4],[5,5,5],[6,6,6]]))
offsets = Variable(torch.LongTensor([0,4]))
red(inputs, offsets)
#  Variable containing:
#   2.5000  2.5000  2.5000
#   5.5000  5.5000  5.5000
#[torch.FloatTensor of size 2x3]

The issue is that by overwriting the internal embedding matrix of the EmbeddingBag, I break the gradient’s ability to backpropagate into earlier modules. Any solution for this?
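
To make the problem concrete, here is a quick check (lin and x are just illustrative names, reusing the TensorReducer and offsets from above): the output is only connected to the fresh nn.Parameter built from inputs.data, so nothing upstream of the reducer receives a gradient.

lin = nn.Linear(3, 3)
x = Variable(torch.randn(6, 3), requires_grad=True)

out = red(lin(x), offsets)   # vectors come from an upstream module
out.sum().backward()

print(lin.weight.grad)  # None: the nn.Parameter(inputs.data) copy cut the graph
print(x.grad)           # None as well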

What about directly using nn.functional.embedding_bag:

weights = Variable(torch.FloatTensor([[1,1,1],[2,2,2],[3,3,3],
                                       [4,4,4],[5,5,5],[6,6,6]]))
offsets = Variable(torch.LongTensor([0,4]))
input = Variable(torch.arange(weights.size(0)).long())

nn.functional.embedding_bag(weights, input, offsets, None, 2, False, 'mean')

Note that you need at least PyTorch 0.3 to use it.
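
For completeness, the same call can be wrapped into the my_operator asked for in the original post (this just mirrors the positional-argument order used above); since no .data copy is involved, gradients flow back into the weight matrix:

def my_operator(input, offsets, mode='mean'):
    # every row of `input` is looked up once, then reduced per offset segment
    indices = Variable(torch.arange(input.size(0)).long())
    return nn.functional.embedding_bag(input, indices, offsets,
                                       None, 2, False, mode)

my_operator(weights, offsets)  # same 2x3 result as above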
