When I use a LongTensor to index an nn.Embedding, the output loses its grad_fn:
import torch
import torch.nn as nn

embed_option = nn.Embedding(4, 10)
vec = embed_option(torch.LongTensor([[0, 1, 2]]))
vec.requires_grad  # returns False
When I run this at the top level of a script, everything works fine. However, when I define the same embedding inside an nn.Module subclass, the output loses its gradient function. Any idea how this could happen?
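A minimal sketch for narrowing this down, assuming the module simply wraps the embedding. The class name Tagger and its freeze flag are hypothetical, for illustration only. It compares a normal call with two situations known to make an embedding output have requires_grad=False and grad_fn=None: running forward under torch.no_grad(), and having the embedding weight frozen (for example, nn.Embedding.from_pretrained freezes the weight by default via freeze=True). Calling model.eval() is included to show that it does not change requires_grad.

```python
import torch
import torch.nn as nn

class Tagger(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, num_options=4, dim=10, freeze=False):
        super().__init__()
        self.embed_option = nn.Embedding(num_options, dim)
        # Freezing the weight is one way the output ends up detached from autograd
        self.embed_option.weight.requires_grad_(not freeze)

    def forward(self, idx):
        return self.embed_option(idx)

idx = torch.LongTensor([[0, 1, 2]])

# Ordinary call: the output is part of the autograd graph
model = Tagger()
out = model(idx)
print(out.requires_grad, out.grad_fn is not None)  # True True

# eval() only changes layers like dropout/batchnorm; gradients are still tracked
model.eval()
out_eval = model(idx)
print(out_eval.requires_grad)  # True

# Case 1: forward runs inside torch.no_grad()
with torch.no_grad():
    out_nograd = model(idx)
print(out_nograd.requires_grad, out_nograd.grad_fn)  # False None

# Case 2: the embedding weight does not require grad
frozen = Tagger(freeze=True)
out_frozen = frozen(idx)
print(out_frozen.requires_grad, out_frozen.grad_fn)  # False None
```

If the module's output matches one of the two "False None" cases, checking whether forward is wrapped in torch.no_grad() (or torch.inference_mode()) and printing self.embed_option.weight.requires_grad are reasonable first steps.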