Is this a bug in torch.nn.EmbeddingBag?


#1

On the GPU, the output of nn.EmbeddingBag is filled with NaN for empty bags, but on the CPU it is zeros.
It should be zeros in both cases according to the PyTorch documentation.


Here is the test code:

import torch
import torch.nn as nn

# CPU: bag 1 is empty (consecutive offsets 2 and 2), and its row comes out as zeros
u_embedding = nn.EmbeddingBag(180, 100)
offset = torch.LongTensor([0, 2, 2])
word_in = torch.LongTensor([134, 134, 23, 134, 53])  # indices must be < 180
out = u_embedding(word_in, offset)
# out[1] is zeros

# GPU: the same empty bag comes back as NaN
u_embedding = nn.EmbeddingBag(180, 100).cuda()
offset = torch.cuda.LongTensor([0, 2, 2])
word_in = torch.cuda.LongTensor([134, 134, 23, 134, 53])
out = u_embedding(word_in, offset)
# out[1] is NaN

Is this a bug? What should I do if I want to use the GPU and have empty bags default to zeros?
Is there an option like "padding_idx" in nn.EmbeddingBag layers?


(Simon Wang) #2

Yes, it's a bug. I'm fixing it. See https://github.com/pytorch/pytorch/issues/11739


#3

Thank you.

But what should I do to use the fixed version as early as possible?
Or is there any other way to set 'padding_idx' in nn.EmbeddingBag layers?
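
In case it is useful, one interim workaround might be to zero out the rows of empty bags manually after the forward pass. A rough sketch, assuming the offsets are sorted and cover the whole input (the embedding_bag_zero_empty helper is my own, not a PyTorch API):

import torch
import torch.nn as nn

def embedding_bag_zero_empty(bag, indices, offsets):
    out = bag(indices, offsets)
    # bag i is empty when its start offset equals the next bag's start
    # (or the end of the input for the last bag)
    last = torch.tensor([indices.numel()], dtype=offsets.dtype, device=offsets.device)
    ends = torch.cat([offsets[1:], last])
    empty = offsets.eq(ends)
    # overwrite the rows of empty bags (NaN on the GPU) with zeros
    return out.masked_fill(empty.unsqueeze(1), 0.0)

With the first example above, embedding_bag_zero_empty(u_embedding, word_in, offset) should give a zero row for out[1] on both CPU and GPU.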


#4

I saw that the fix for this bug only changes a few lines.

Can I fix the problem by changing a few files of the PyTorch package myself?


(Simon Wang) #5

The fix still has some issues to be smoothed out, so it hasn't been merged yet. Since it involves a C++ change, you would have to compile from source to make it work.


(Simon Wang) #6

FYI, the fix just landed on master.


#7

I successfully installed it from source and it works.

Thank you very much.


#8

Sorry to disturb you again.

After I updated PyTorch from source, I got RuntimeError: cuda runtime error (9) : invalid configuration argument at /data/users/sdf/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:257 when running my code.

When I replace all cuda() calls with cpu(), it works perfectly, so there may be another bug in the EmbeddingBag GPU code. (A reduced snippet with my guess at the trigger follows the full test code below.)

Here is the test code:

import torch.optim as optim
import torch
import torch.nn as nn

class SkipGramModel(nn.Module):
    def __init__(self, component_size, word_size, dim):
        super(SkipGramModel, self).__init__()
        self.emb_size = dim
        self.component_size = component_size
        self.word_size = word_size
        self.atten_layers = nn.Embedding(word_size,1)
        # the only EmbeddingBag in the model; the reported CUDA error comes from EmbeddingBag.cu
        self.u_embeddings = nn.EmbeddingBag(component_size,dim)
        self.word_embeddings = nn.Embedding(word_size,dim,sparse=True)
        self.v_embeddings = nn.Embedding(word_size,dim,sparse=True)
        self.m = nn.Sigmoid()
        self.init_emb()

    def init_emb(self):
        initrange = 0.5 / self.emb_size
        self.word_embeddings.weight.data.uniform_(-initrange,initrange)
        self.u_embeddings.weight.data.uniform_(-initrange, initrange)
        self.v_embeddings.weight.data.uniform_(-0, 0)  # i.e. initialized to zeros
        # replace the [word_size, 1] weight with a [word_size, 5] attention table
        atten = torch.zeros([self.word_size, 5])
        atten[:, 0] += torch.log(torch.FloatTensor([4]))
        self.atten_layers.weight.data = atten


    def forward(self, word_in,component_in, word_out, offset):
        # component_in and offset arrive as Python lists; component_in[3] can be
        # empty, which makes com2_in a zero-element CUDA tensor
        char_in = torch.cuda.LongTensor(component_in[0])
        redical_in = torch.cuda.LongTensor(component_in[1])
        com1_in = torch.cuda.LongTensor(component_in[2])
        com2_in = torch.cuda.LongTensor(component_in[3])
        offset1 = torch.cuda.LongTensor(offset[0])
        offset2 = torch.cuda.LongTensor(offset[1])
        offset3 = torch.cuda.LongTensor(offset[2])
        offset4 = torch.cuda.LongTensor(offset[3])
        attention = torch.softmax(self.atten_layers(word_in),dim=-1).unsqueeze(1)
        emb_uword = self.word_embeddings(word_in)
        emb_char = self.u_embeddings(char_in,offset1)
        emb_redical = self.u_embeddings(redical_in,offset2)
        emb_com1 = self.u_embeddings(com1_in,offset3)
        emb_com2 = self.u_embeddings(com2_in,offset4)
        emb_all = torch.stack((emb_uword,emb_char,emb_redical,emb_com1,emb_com2),1)
        emb_vword = self.v_embeddings(word_out)
        emb_mixin = torch.bmm(attention,emb_all).squeeze(1)
        score = torch.mul(emb_mixin, emb_vword)
        score = torch.sum(score, dim=-1)
        score = self.m(score)
        return score

if __name__ == '__main__':

    model = SkipGramModel(364, 180, 100).cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.025)
    Lossfunc = nn.BCELoss(reduction='sum')
    for _ in range(100):
        word_in = torch.cuda.LongTensor([2]*128)
        word_out = torch.cuda.LongTensor([2]*128)
        label = torch.cuda.FloatTensor([1]*128)
        component_in = [[3,5],[2,4,5],[2,3,4],[]]  # the last component list is empty
        offset = [[0]*127+[1],[0]*127+[1],[0]*128,[0]*128]
        outs = model.forward(word_in, component_in, word_out, offset)
        loss = Lossfunc(outs, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
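
My guess is that the empty component list (component_in[3] = []) is the trigger, since it turns into a zero-element input tensor for one of the EmbeddingBag calls. A reduced sketch of that situation, assuming this is indeed what hits the kernel:

import torch
import torch.nn as nn

bag = nn.EmbeddingBag(364, 100).cuda()
empty_in = torch.cuda.LongTensor([])        # zero-element input tensor
offsets = torch.cuda.LongTensor([0] * 128)  # 128 bags, all of them empty
out = bag(empty_in, offsets)                # where I would expect the same invalid configuration error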

(Linjie Xu) #9

I tried python setup.py install to reproduce your problem, but I got No module named 'tools.setup_helpers'.
Then I just used my own PyTorch (0.4.0 on Win10) and it hit a similar problem:
RuntimeError: cuda runtime error (9) : invalid configuration argument at C:/Users/Administrator/Downloads/new-builder/win-wheel/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:281

Since my version is older, this issue should not be caused by the fix for the previous one.