Solving the "leaf variable has been moved into the graph interior" error in a custom embedding (PyTorch C++)

I have written a custom embedding layer

#include <torch/torch.h>

#include <vector>

struct CustomEmbeddingImpl : torch::nn::Module
{
    CustomEmbeddingImpl(int vocab_size, int embedding_size)
    {
        // Index 0 is reserved as the padding row, hence vocab_size + 1.
        embedding = register_module("embedding", torch::nn::Embedding(vocab_size + 1, embedding_size));
    }

    torch::Tensor forward(torch::Tensor indices, torch::Tensor weights)
    {
        assert((indices.sizes()[0] == weights.sizes()[0]) && (indices.sizes()[1] == weights.sizes()[1]));

        // (batch, seq, emb) -> (batch, emb, seq)
        torch::Tensor x = torch::transpose(embedding->forward(indices.to(torch::kInt64)), 1, 2);
        // (batch, seq) -> (batch, seq, 1)
        torch::Tensor y = torch::unsqueeze(weights, 2);

        // Weighted sum of the looked-up embeddings: (batch, emb)
        torch::Tensor z = torch::matmul(x, y).squeeze(2);
        return z;
    }

    void initialize()  // init randomly
    {
        // Tensor is a handle type: this copy shares storage with the registered parameter.
        torch::Tensor weight = embedding->named_parameters()["weight"];
        torch::nn::init::normal_(weight, 0, 1);
        for (int j = 0; j < weight.sizes()[1]; j++)
            weight[0][j] = 0.0;  // zero out the padding row
    }

    void initialize(std::vector<std::vector<float>>& pretrained_embs)  // init with pretrained
    {
        torch::Tensor weight = embedding->named_parameters()["weight"];

        assert((pretrained_embs.size() == weight.sizes()[0] - 1) && (pretrained_embs[0].size() == weight.sizes()[1]));

        for (int i = 0; i < pretrained_embs.size(); i++)
        {
            for (int j = 0; j < weight.sizes()[1]; j++)
                weight[i + 1][j] = pretrained_embs[i][j];
        }
        for (int j = 0; j < weight.sizes()[1]; j++)
            weight[0][j] = 0.0;  // zero out the padding row
    }

    torch::nn::Embedding embedding{nullptr};
};
TORCH_MODULE(CustomEmbedding);

Since I am modifying the embedding layer's weight parameter, when I try to use autograd here I get an error saying that the leaf variable has been moved into the graph interior. How do I solve this while still being able to do both kinds of initialization (and, in both cases, initialize row 0 (= padding_idx) with zeros)? Thanks.

Whenever you modify the weights for initialization, make sure autograd is disabled while you're doing these modifications. You can do so with torch::NoGradGuard guard;.
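
For example, the two initializers from the question could be wrapped like this. This is a minimal sketch reusing the embedding member and the named_parameters() lookup from the code above; it replaces the per-element zeroing loop with weight[0].zero_(), which is equivalent for the padding row. The guard is scope-based, so gradient tracking is restored automatically when it goes out of scope.

void initialize()  // init randomly, with autograd disabled during the in-place writes
{
    torch::NoGradGuard guard;  // operations in this scope are not recorded in the autograd graph

    torch::Tensor weight = embedding->named_parameters()["weight"];
    torch::nn::init::normal_(weight, 0, 1);
    weight[0].zero_();  // padding row (index 0) stays all zeros
}

void initialize(std::vector<std::vector<float>>& pretrained_embs)  // init with pretrained
{
    torch::NoGradGuard guard;

    torch::Tensor weight = embedding->named_parameters()["weight"];
    for (size_t i = 0; i < pretrained_embs.size(); i++)
        for (size_t j = 0; j < pretrained_embs[i].size(); j++)
            weight[i + 1][j] = pretrained_embs[i][j];
    weight[0].zero_();
}

The parameter itself keeps requires_grad = true, so training afterwards is unaffected; only the initialization writes are excluded from the graph.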

Thanks a lot @albanD.