Issue with the max_norm parameter of torch.nn.Embedding

Hello,

I am using torch.nn.Embedding to embed my model’s categorical input features. When I set the max_norm parameter to 1, I get the following error while training the model:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [59, 20]] is at version 4; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I would appreciate any help on this problem.
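The hint at the end of the error message refers to autograd’s anomaly detection. When the backward pass fails in this mode, PyTorch additionally prints the traceback of the forward operation whose gradient could not be computed, which helps locate the offending line in a large, modular codebase. A minimal sketch, using the failing pattern from the nn.Embedding docs as a stand-in for the real model (the sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(3, 5, max_norm=1.0)
W = torch.randn(7, 5, requires_grad=True)
idx = torch.tensor([1, 2])

caught = None
# detect_anomaly() records forward tracebacks, so a failing backward
# also reports where the forward op that saved the tensor was created
with torch.autograd.detect_anomaly():
    a = emb.weight @ W.t()  # saves emb.weight for the backward pass
    b = emb(idx) @ W.t()    # max_norm renormalizes emb.weight in-place
    loss = (a.unsqueeze(0) + b.unsqueeze(1)).sigmoid().prod()
    try:
        loss.backward()
    except RuntimeError as err:
        caught = err

print(caught)
```

The forward traceback is emitted as a warning on stderr; anomaly detection slows training down, so it is best enabled only while debugging.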

I cannot reproduce the issue using torch==1.13.0+cu117 with this code snippet:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=100, max_norm=1.)

for _ in range(10):
    x = torch.randint(0, 10, (8,))
    out = emb(x)
    out.mean().backward()

so could you post a minimal, executable code snippet reproducing the issue, please?

Thanks for your response. I can also run the code snippet you’ve provided without any issue.
My codebase is large and highly modular, so I’ll try to extract a snippet that reproduces the same issue.
In the meantime, I wanted to know whether the issue I’m facing is related to the following note in the PyTorch documentation for torch.nn.Embedding:
When max_norm is not None, Embedding’s forward method will modify the weight tensor in-place. Since tensors needed for gradient computations cannot be modified in-place, performing a differentiable operation on Embedding.weight before calling Embedding’s forward method requires cloning Embedding.weight when max_norm is not None.
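The in-place modification described in this note can be observed directly: after a forward pass, the rows of weight that were looked up are rescaled so their norm is at most max_norm, while the other rows stay untouched. A minimal sketch (the sizes and the fill value of 2.0 are arbitrary choices that make every row norm exceed 1):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=4, embedding_dim=3, max_norm=1.0)
with torch.no_grad():
    emb.weight.fill_(2.0)  # every row norm is now 2 * sqrt(3), about 3.46

before = emb.weight.detach().clone()
idx = torch.tensor([0, 2])
out = emb(idx)  # forward rescales rows 0 and 2 of emb.weight in-place

row_norms = emb.weight.detach().norm(dim=1)
print(before.norm(dim=1))  # all rows above max_norm
print(row_norms)           # rows 0 and 2 clipped to about 1.0, rows 1 and 3 unchanged
```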

Yes, it could be related, but I would assume you would have to use the .weight parameter explicitly for this to happen. Is this the case, i.e., are you using embedding.weight directly in your code?

I am not using embedding.weight directly. I define the optimizer this way:
opt = torch.optim.Adam(params=model.parameters())
I would appreciate it if you could provide some details on what using embedding.weight directly would look like.
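Using embedding.weight directly means that some operation in the forward pass takes the weight tensor itself as input, rather than only calling the module. One hypothetical example is an L2 penalty on the embedding table computed before the lookup. The sketch below (the penalty, its 1e-3 factor, and the train_step helper are illustrative, not taken from the original code) shows that this pattern fails with max_norm and that cloning the weight avoids the error:

```python
import torch
import torch.nn as nn


def train_step(clone_weight: bool) -> nn.Embedding:
    torch.manual_seed(0)
    emb = nn.Embedding(10, 4, max_norm=1.0)
    x = torch.randint(0, 10, (8,))
    w = emb.weight.clone() if clone_weight else emb.weight
    # hypothetical penalty: pow() saves its input for the backward pass
    penalty = 1e-3 * w.pow(2).sum()
    out = emb(x)  # max_norm modifies emb.weight in-place here
    loss = out.mean() + penalty
    loss.backward()
    return emb


emb_ok = train_step(clone_weight=True)  # backward succeeds
print(emb_ok.weight.grad.shape)         # the gradient still reaches the table

failed_without_clone = False
try:
    train_step(clone_weight=False)
except RuntimeError as err:
    failed_without_clone = True
    print("without clone:", err)
```

The clone does not cut the gradient: CloneBackward passes the penalty’s gradient straight through to emb.weight.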

The docs show a usage pattern that could lead to this error:

import torch
import torch.nn as nn

# copied from the docs
n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=True)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t()  # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()

# now remove the clone and it will fail
n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=True)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight @ W.t()  # no clone: weight is saved as-is for the backward pass
b = embedding(idx) @ W.t()  # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 5]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
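The "version" numbers in this message come from a counter that PyTorch keeps on every tensor and increments on each in-place operation. Autograd records the counter when it saves a tensor for the backward pass and checks it again during backward. The counter is exposed through the internal, underscore-prefixed attribute _version (not part of the public API), so the following sketch is only meant as a debugging aid:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(3, 5, max_norm=1.0)
idx = torch.tensor([1, 2])

v0 = emb.weight._version
_ = emb(idx)  # internally, the renorm is an in-place op run under no_grad
v1 = emb.weight._version
print(v0, v1)  # the counter has advanced although no_grad was active
```

Grad mode does not stop the counter from advancing, which is why a tensor saved before the lookup no longer matches the expected version at backward time.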

So I was just referring to this example.

Thanks for your answer. I find this example from the docs hard to understand. What is the purpose of having both ‘a’ and ‘b’, and why is ‘out’ defined as out = (a.unsqueeze(0) + b.unsqueeze(1))?

Do we need to first clone the entire embedding tensor as in ‘a’, and then find the embeddings for our desired indices as in ‘b’? And then how do ‘a’ and ‘b’ need to be added?
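To make the shapes in the docs example concrete: ‘a’ holds one score vector per row of the embedding table, ‘b’ one per looked-up index, and the two unsqueeze calls only make them broadcast against each other. The arithmetic itself is arbitrary; it just combines a value computed from the raw weight with one computed from the lookup, so that both paths take part in the backward pass. A sketch based on the docs example (with max_norm=1.0 instead of True):

```python
import torch
import torch.nn as nn

n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=1.0)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])

a = embedding.weight.clone() @ W.t()   # (n, m): uses every row of the weight
b = embedding(idx) @ W.t()             # (len(idx), m): uses only the looked-up rows
out = a.unsqueeze(0) + b.unsqueeze(1)  # (1, n, m) + (len(idx), 1, m) -> (len(idx), n, m)
print(a.shape, b.shape, out.shape)
# torch.Size([3, 7]) torch.Size([2, 7]) torch.Size([2, 3, 7])
```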

In my code, I don’t have W explicitly; I am assuming that W represents the weights applied by the torch.nn.Linear layers. So I just need to prepare the input (which includes the embeddings for the categorical features) that goes into my network.
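W in the docs example simply stands for any other tensor taking part in the computation. A model that only calls the embedding module and feeds its output into nn.Linear layers never uses embedding.weight directly, so it should train with max_norm set. A minimal sketch of such a setup; the TabularModel class, the random data, and the sizes (59 categories of dimension 20, chosen on the assumption that the [59, 20] tensor in the error is the embedding table) are hypothetical placeholders:

```python
import torch
import torch.nn as nn


class TabularModel(nn.Module):
    """Hypothetical model: one categorical feature plus numeric features."""

    def __init__(self, num_categories: int, emb_dim: int, num_numeric: int):
        super().__init__()
        self.emb = nn.Embedding(num_categories, emb_dim, max_norm=1.0)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim + num_numeric, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, cat: torch.Tensor, num: torch.Tensor) -> torch.Tensor:
        # only the module call touches the embedding table
        x = torch.cat([self.emb(cat), num], dim=1)
        return self.mlp(x)


torch.manual_seed(0)
model = TabularModel(num_categories=59, emb_dim=20, num_numeric=4)
opt = torch.optim.Adam(params=model.parameters())
losses = []
for _ in range(5):
    cat = torch.randint(0, 59, (32,))
    num = torch.randn(32, 4)
    target = torch.randn(32, 1)
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(cat, num), target)
    loss.backward()  # no error: nothing saved emb.weight before the lookup
    opt.step()
    losses.append(loss.item())
print(losses)
```

If a setup like this trains but the real model does not, the difference is likely some other place where the table (or a tensor sharing its storage) is used before the embedding call.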

I would greatly appreciate any guidance on this, as understanding this example would help me adapt my code accordingly.

The actual "logic" of the example isn’t important, as it only demonstrates the .clone() call that is needed when you use embedding.weight directly in any operation that needs the original value of .weight for the gradient calculation:

embedding.weight.clone() @ W.t()

The values of a, b, etc. are only computed to reproduce the error a user might run into in a real model.
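Put differently, the error only appears when a value saved from .weight before the lookup is needed again during backward. If the direct use comes after the embedding call, autograd saves the already renormalized weight and nothing modifies it afterwards, so the same computation should run without .clone(). A sketch based on the docs example with the two lines swapped:

```python
import torch
import torch.nn as nn

n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=1.0)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])

b = embedding(idx) @ W.t()    # the in-place renorm of the weight happens first
a = embedding.weight @ W.t()  # then the already renormalized weight is saved
out = a.unsqueeze(0) + b.unsqueeze(1)
loss = out.sigmoid().prod()
loss.backward()               # no clone needed in this order
print(embedding.weight.grad.shape)
```

This only holds as long as the same embedding is not called again before backward(); a later forward pass would modify the saved weight in-place once more.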