Embeddings with half precision

Is it possible to use nn.Embedding with half-precision models? The half-precision code below crashes in backward:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

# Float32 (the default) — works
embedding = nn.Embedding(embedding_dim=5, num_embeddings=10)
embedding.cuda()

x_device = torch.LongTensor([1, 2, 0, 1]).cuda()
xv = Variable(x_device)
o = embedding(xv)
t = torch.zeros(o.size()).cuda()
o.backward(t)
```

```python
# Half — crashes in backward
embedding = nn.Embedding(embedding_dim=5, num_embeddings=10)
embedding.cuda().half()

x_device = torch.LongTensor([1, 2, 0, 1]).cuda()
xv = Variable(x_device)
o = embedding(xv)
print(o)
t = torch.zeros(o.size()).cuda().half()
o.backward(t)
```
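Until the backward kernel handles half precision, one workaround is to keep the embedding table itself in float32 and cast only its output to half, so the half-precision embedding backward is never invoked. This is a sketch of that idea, not code from the thread; it is shown on CPU and without Variable wrappers, which newer PyTorch versions no longer require:

```python
import torch
import torch.nn as nn

# Keep the embedding weights in float32; only the activations go to fp16.
embedding = nn.Embedding(embedding_dim=5, num_embeddings=10)

x = torch.tensor([1, 2, 0, 1], dtype=torch.long)
o = embedding(x).half()   # output is fp16, weights stay fp32

t = torch.zeros_like(o)   # fp16 gradient for the output
o.backward(t)

# The cast's backward converts the gradient back to float32 before it
# reaches the embedding, so the weight gradient accumulates in fp32.
print(embedding.weight.grad.dtype)
```

The same pattern works with `.cuda()` inserted as in the snippets above; the downstream layers see half-precision activations while the sparse embedding gradient stays in float32.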

I think there is a bug in the CUDA path of the PyTorch Embedding backward code. I think this line should instead be:

```python
if grad_output.is_cuda:
```

Thanks. I’ll try to understand the Embedding code better and create a PR.