I created an embedding, and passing out-of-range index tensors fails on CPU (it raises an IndexError), but on GPU it returns a tensor: all rows except the first change with every call, and even the first row does not match embedding.weight.
Torch version = 1.5.0
>>> a = torch.nn.Embedding(1, 768)
>>> a(torch.LongTensor([1,2,3]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<path>/envs/vl-bert/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "<path>/envs/vl-bert/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 114, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "<path>/vl-bert/lib/python3.6/site-packages/torch/nn/functional.py", line 1724, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
>>> a.cuda()(torch.LongTensor([1,2,3]).cuda())
tensor([[1.4013e-45, 0.0000e+00, 2.8026e-45, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00]], device='cuda:0', grad_fn=<EmbeddingBackward>)
>>> a.cuda()(torch.LongTensor([1,2,3]).cuda())
tensor([[ 1.4013e-45, 0.0000e+00, 2.8026e-45, ..., 1.0869e+00,
-1.7131e+00, -6.9908e-01],
[-5.6997e-01, 1.6486e+00, 1.7096e+00, ..., 1.0869e+00,
1.7131e+00, 6.9908e-01],
[ 5.6996e-01, 1.6486e+00, 1.7096e+00, ..., 1.7427e+00,
2.0000e+00, 1.7774e+00]], device='cuda:0',
grad_fn=<EmbeddingBackward>)
>>> a.cuda()(torch.LongTensor([1,2,3]).cuda())
tensor([[1.4013e-45, 0.0000e+00, 2.8026e-45, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 2.0000e+00, 0.0000e+00, ..., 1.0869e+00, 1.7131e+00,
6.9908e-01],
[6.5202e+06, 4.3207e+00, 8.6057e-02, ..., 1.7427e+00, 2.0000e+00,
1.7774e+00]], device='cuda:0', grad_fn=<EmbeddingBackward>)
>>> a.weight
Parameter containing:
tensor([[ 0.7804, 1.5051, 0.0861, -0.9269, -0.8105, -2.7018, -1.2860, -0.4517,
0.6019, 1.2832, 2.1942, 0.3216, 1.9599, 0.8146, 0.0085, 0.6976,
1.9618, 0.0783, 1.3515, 0.8830, 0.8101, -2.4665, 2.6164, 1.1543,
-0.8128, -0.9217, 1.3534, -0.3387, 0.1712, 1.1185, -0.5681, 0.2406,
1.8387, 0.7704, 1.6712, 0.4060, -1.2792, -0.3
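
For context (my reading of the transcript, not a confirmed explanation): `nn.Embedding(num_embeddings, dim)` only accepts indices in `0 .. num_embeddings-1`, so `Embedding(1, 768)` can only be indexed with `0`, and `[1, 2, 3]` is out of range. The CPU kernel bounds-checks the lookup and raises `IndexError`; the CUDA kernel in this version apparently does not, so the out-of-range gather reads whatever happens to sit in device memory, which would explain the garbage values that change between calls. A minimal sketch of the fix, sizing the table to cover the indices actually used:

```python
import torch

# Valid indices for nn.Embedding(num_embeddings, dim) are 0 .. num_embeddings-1,
# so a table indexed with [1, 2, 3] needs num_embeddings >= 4.
emb = torch.nn.Embedding(4, 768)

out = emb(torch.LongTensor([1, 2, 3]))
print(out.shape)  # torch.Size([3, 768])

# Each output row is a copy of the corresponding weight row.
assert torch.equal(out[0], emb.weight[1])
assert torch.equal(out[2], emb.weight[3])
```

On newer PyTorch builds an out-of-range CUDA lookup typically aborts with a device-side assert rather than returning data, so validating index ranges on CPU first is the easier way to debug this.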