Is this a bug in torch.nn.EmbeddingBag?

The outputs of nn.EmbeddingBag are filled with NaN on the GPU for empty bags, but with zeros on the CPU.
According to the PyTorch documentation, empty bags should return zeros.


Here is the test code:

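import torch
import torch.nn as nn

# CPU: offsets [0, 2, 2] define three bags; bag 1 is empty
u_embedding = nn.EmbeddingBag(300, 100)  # num_embeddings must exceed the largest index (234)
offset = torch.LongTensor([0, 2, 2])
word_in = torch.LongTensor([234, 234, 23, 234, 53])
out = u_embedding(word_in, offset)
# out[1] is zeros

# GPU: same inputs
u_embedding = nn.EmbeddingBag(300, 100).cuda()
offset = torch.cuda.LongTensor([0, 2, 2])
word_in = torch.cuda.LongTensor([234, 234, 23, 234, 53])
out = u_embedding(word_in, offset)
# out[1] is NaN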

Is this a bug? What should I do if I want to use the GPU and have empty bags default to zeros?
Is there a setting like "padding_idx" for nn.EmbeddingBag layers?
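Until a fixed build is available, one possible workaround is to compute each bag's length from the offsets and overwrite the rows of empty bags with zeros. This is a minimal sketch, assuming 1-D input with offsets and the default mode='mean'; the helper name embedding_bag_zero_empty is hypothetical. Note that multiplying by a 0/1 mask would not help, because NaN * 0 is still NaN, so the sketch uses masked_fill instead.

```python
import torch
import torch.nn as nn

def embedding_bag_zero_empty(bag: nn.EmbeddingBag, input: torch.Tensor,
                             offsets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run an EmbeddingBag and force empty bags to zero."""
    out = bag(input, offsets)
    # bag i covers input[offsets[i]:offsets[i+1]]; the last bag ends at len(input)
    ends = torch.cat([offsets[1:], offsets.new_tensor([input.numel()])])
    lengths = ends - offsets
    empty = (lengths == 0).unsqueeze(1)  # shape (num_bags, 1), broadcast over dim
    # masked_fill overwrites NaN as well, unlike multiplying by a 0/1 mask
    return out.masked_fill(empty, 0.0)

# usage with the inputs from the report (bag 1 is empty)
bag = nn.EmbeddingBag(300, 100)
out = embedding_bag_zero_empty(bag, torch.tensor([234, 234, 23, 234, 53]),
                               torch.tensor([0, 2, 2]))
```

The same call works unchanged on CUDA tensors, since new_tensor keeps the device and dtype of offsets.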

Yes, it's a bug. I'm working on a fix. See https://github.com/pytorch/pytorch/issues/11739

Thank you.

But what should I do to get the new version as early as possible?
Or is there any other way to set "padding_idx" in nn.EmbeddingBag layers?
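One way to emulate a padding index, assuming the goal is simply to ignore a designated token inside each bag, is to drop those entries from the input and recompute the offsets before calling EmbeddingBag. The sketch below is an illustration, not a built-in option; the helper drop_padding and the choice of 0 as the padding index are hypothetical.

```python
import torch
import torch.nn as nn

def drop_padding(input: torch.Tensor, offsets: torch.Tensor, padding_idx: int):
    """Hypothetical helper: remove padding_idx entries and recompute offsets."""
    keep = input != padding_idx
    # kept_before[j] = number of kept entries in input[:j]
    kept_before = torch.cat([
        torch.zeros(1, dtype=torch.long, device=input.device),
        keep.long().cumsum(0),
    ])
    # each bag now starts at the count of kept entries before its old start
    new_offsets = kept_before[offsets]
    return input[keep], new_offsets

# usage: bags are [0, 5], [0], [7, 8]; bag 1 holds only padding and becomes empty
bag = nn.EmbeddingBag(10, 4)
inp = torch.tensor([0, 5, 0, 7, 8])
off = torch.tensor([0, 2, 3])
new_inp, new_off = drop_padding(inp, off, padding_idx=0)
out = bag(new_inp, new_off)  # rows: mean(w[5]), empty bag, mean(w[7], w[8])
```

A bag that contains only padding turns into an empty bag, so on a GPU build without the fix it would still come back as NaN and needs the same zero-masking treatment.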

I saw that the fix for this bug is only a few lines of code.

Can I fix this problem by patching a few lines in my installed PyTorch package?

The fix has some issues to be smoothed out, so it hasn't been merged yet. Since it involves a C++ change, you would have to compile PyTorch from source to make it work.

FYI, the fix just landed on master.

I successfully installed it from source, and it works.

Thank you very much.
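A quick way to confirm that a given build returns zeros for empty bags is to run the same small check on every available device. This is a minimal sketch; the function name empty_bag_is_zero is hypothetical, and it only tests the default mode='mean'.

```python
import torch
import torch.nn as nn

def empty_bag_is_zero(device: str) -> bool:
    """Hypothetical check: does an empty bag produce an all-zero row on this device?"""
    bag = nn.EmbeddingBag(10, 4).to(device)
    inp = torch.tensor([1, 2, 3], device=device)
    off = torch.tensor([0, 2, 2], device=device)  # bag 1 is empty
    out = bag(inp, off)
    return bool(torch.all(out[1] == 0))

devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for d in devices:
    print(torch.__version__, d, empty_bag_is_zero(d))
```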

Sorry to bother you again.

After I updated PyTorch from source, I got the following error when running my code: RuntimeError: cuda runtime error (9) : invalid configuration argument at /data/users/sdf/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:257

When I replace all cuda() calls with cpu(), it works perfectly, so there may be another bug in the EmbeddingBag GPU code.

Here is the test code:

import torch
import torch.nn as nn
import torch.optim as optim

class SkipGramModel(nn.Module):
    def __init__(self, component_size, word_size, dim):
        super(SkipGramModel, self).__init__()
        self.emb_size = dim
        self.component_size = component_size
        self.word_size = word_size
        # atten = torch.zeros([word_size,5])
        # atten[:,0] += torch.log(torch.FloatTensor([4]))
        # self.atten = nn.Parameter(atten,requires_grad=True)
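        # one attention logit per source (word, char, radical, com1, com2);
        # init_emb fills this with a [word_size, 5] matrix
        self.atten_layers = nn.Embedding(word_size, 5)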
        self.u_embeddings = nn.EmbeddingBag(component_size,dim)
        self.word_embeddings = nn.Embedding(word_size,dim,sparse=True)
        self.v_embeddings = nn.Embedding(word_size,dim,sparse=True)
        # self.attention_matrix = 0.5 * torch.ones(self.word_size, 1).cuda()
        self.m = nn.Sigmoid()
        self.init_emb()

    def init_emb(self):
        initrange = 0.5 / self.emb_size
        self.word_embeddings.weight.data.uniform_(-initrange,initrange)
        self.u_embeddings.weight.data.uniform_(-initrange, initrange)
        self.v_embeddings.weight.data.zero_()
        atten = torch.zeros([self.word_size, 5])
        atten[:, 0] += torch.log(torch.FloatTensor([4]))
        self.atten_layers.weight.data = atten


    def forward(self, word_in,component_in, word_out, offset):
        char_in = torch.cuda.LongTensor(component_in[0])
        redical_in = torch.cuda.LongTensor(component_in[1])
        com1_in = torch.cuda.LongTensor(component_in[2])
        com2_in = torch.cuda.LongTensor(component_in[3])
        offset1 = torch.cuda.LongTensor(offset[0])
        offset2 = torch.cuda.LongTensor(offset[1])
        offset3 = torch.cuda.LongTensor(offset[2])
        offset4 = torch.cuda.LongTensor(offset[3])
        attention = torch.softmax(self.atten_layers(word_in),dim=-1).unsqueeze(1)
        emb_uword = self.word_embeddings(word_in)
        emb_char = self.u_embeddings(char_in,offset1)
        emb_redical = self.u_embeddings(redical_in,offset2)
        emb_com1 = self.u_embeddings(com1_in,offset3)
        emb_com2 = self.u_embeddings(com2_in,offset4)
        emb_all = torch.stack((emb_uword,emb_char,emb_redical,emb_com1,emb_com2),1)
        emb_vword = self.v_embeddings(word_out)
        emb_mixin = torch.bmm(attention,emb_all).squeeze(1)
        score = torch.mul(emb_mixin, emb_vword)
        score = torch.sum(score, dim=-1)
        score = self.m(score)
        return score

if __name__ == '__main__':

    model = SkipGramModel(364, 180, 100).cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.025)
    Lossfunc = nn.BCELoss(reduction='sum')
    for _ in range(100):
        word_in = torch.cuda.LongTensor([2]*128)
        word_out = torch.cuda.LongTensor([2]*128)
        label = torch.cuda.FloatTensor([1]*128)
        component_in = [[3,5],[2,4,5],[2,3,4],[]]
        offset = [[0]*127+[1],[0]*127+[1],[0]*128,[0]*128]
        outs = model(word_in, component_in, word_out, offset)
        loss = Lossfunc(outs, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
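In this test, component_in[3] is [] and offset[3] is all zeros, so the fourth EmbeddingBag call receives a completely empty index tensor. One plausible but unconfirmed explanation for "invalid configuration argument" is that the CUDA kernel gets launched with a zero-sized grid for that empty input. Assuming that is the cause, a minimal guard skips the call entirely and returns zeros, which is what every bag would be anyway. The helper safe_embedding_bag is hypothetical.

```python
import torch
import torch.nn as nn

def safe_embedding_bag(bag: nn.EmbeddingBag, input: torch.Tensor,
                       offsets: torch.Tensor) -> torch.Tensor:
    """Hypothetical guard: avoid calling EmbeddingBag when every bag is empty."""
    if input.numel() == 0:
        # no indices at all: each of the len(offsets) bags is empty, so return zeros
        return bag.weight.new_zeros(offsets.numel(), bag.embedding_dim)
    return bag(input, offsets)

# usage mirroring component_in[3] = [] with 128 all-zero offsets
bag = nn.EmbeddingBag(364, 100)
offsets = torch.zeros(128, dtype=torch.long)
emb_com2 = safe_embedding_bag(bag, torch.tensor([], dtype=torch.long), offsets)
```

Inside the model, this would replace the direct call, e.g. emb_com2 = safe_embedding_bag(self.u_embeddings, com2_in, offset4). If the error persists with the guard in place, the empty input is not the trigger.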

I tried running python setup.py install to reproduce your problem, but I got No module named 'tools.setup_helpers'.
So I just used my existing PyTorch (0.4.0 on Windows 10), and a similar problem occurs:
RuntimeError: cuda runtime error (9) : invalid configuration argument at C:/Users/Administrator/Downloads/new-builder/win-wheel/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:281

Since my version is older, this issue is probably not caused by the fix for the previous issue.