The `weight` parameter of `nn.Embedding` does have this attribute. It is populated once `backward()` has been called:
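import torch
import torch.nn as nn

emb = nn.Embedding(10, 10)
x = torch.randint(0, 10, (10,))  # 10 random indices in [0, 10)
out = emb(x)                     # shape (10, 10): one embedding row per index
out.mean().backward()
print(emb.weight.grad)       # filled in by backward(); same shape as emb.weight
print(emb.weight.grad.data)  # also works, but see the note on .data below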
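To see what that gradient actually contains, here is a small sketch, assuming the default dense gradients (`sparse=False`). Because `out.mean()` averages over `out.numel() == 100` elements, every output element receives a gradient of 1/100, so row `k` of `emb.weight.grad` equals (number of times `k` appears in `x`) / 100, and rows that were never looked up stay zero. The names `counts` and `expected` are just for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(10, 10)
x = torch.randint(0, 10, (10,))
out = emb(x)  # shape (10, 10)
out.mean().backward()

# How often each row index was looked up (0 for rows never used)
counts = torch.bincount(x, minlength=10).to(emb.weight.dtype)
# Each lookup contributes 1 / out.numel() to every column of its row
expected = counts.unsqueeze(1).expand(10, 10) / out.numel()

print(torch.allclose(emb.weight.grad, expected))  # True
```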
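However, note that using `.data` is not recommended: changes made through it are hidden from autograd and can silently lead to wrong gradients. Use `.detach()` instead if you need a tensor that is detached from the graph.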
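As an illustration of why `.data` is discouraged, here is a sketch based on the well-known sigmoid example. The helper `sigmoid_grad_after_inplace` is hypothetical, written only for this demonstration. Sigmoid's backward pass reuses its own output, so zeroing that output in place through `.data` silently yields a zero gradient, whereas doing the same through `.detach()` is detected by autograd and `backward()` raises a `RuntimeError`.

```python
import torch

def sigmoid_grad_after_inplace(use_data: bool):
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()  # sigmoid's backward reuses `out`
    c = out.data if use_data else out.detach()
    c.zero_()  # in-place change to the storage shared with `out`
    out.sum().backward()
    return a.grad

# .data: autograd is not told about the change, so the gradient is silently wrong
print(sigmoid_grad_after_inplace(use_data=True))  # tensor([0., 0., 0.])

# .detach(): the in-place change is detected and backward() raises instead
try:
    sigmoid_grad_after_inplace(use_data=False)
except RuntimeError as err:
    print("RuntimeError:", err)
```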