I’m trying to train some entity embeddings, but the backward pass fails with an error about a leaf variable being moved into the graph interior. Other threads about this error attribute it to in-place operations or to assigning into tensors that require gradients, but as far as I can tell I’m not doing anything like that, so I’m unsure what’s causing it.
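For reference, the pattern those threads describe looks roughly like the sketch below. This is a hypothetical example, not code from my project, and depending on the PyTorch version the error may instead be raised immediately at the assignment:

import torch

# Hypothetical illustration of the failure mode described in other threads.
x = torch.zeros(2)
x.requires_grad = True                 # x is a leaf tensor that requires grad
w = torch.ones(1, requires_grad=True)
x[0] = 2 * w[0]                        # in-place assignment pulls the leaf into the graph
x.sum().backward()                     # fails here (or at the assignment, depending on version)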
My code is based on the code here; here’s my network:
class GameEmbedding(torch.nn.Module):
    def __init__(self, num_games, num_links, embedding_dim=64):
        super(GameEmbedding, self).__init__()
        self.game_embedding = torch.nn.Embedding(num_games, embedding_dim, max_norm=1.0)
        self.link_embedding = torch.nn.Embedding(num_links, embedding_dim, max_norm=1.0)
        self.embedding_dim = embedding_dim

    def forward(self, batch):
        # In the batch, each input is [game, link, label]; label is 1 (true) or -1 (false).
        t1 = self.game_embedding(torch.LongTensor([v[0] for v in batch]))
        t2 = self.link_embedding(torch.LongTensor([v[1] for v in batch]))
        # Batched dot product of each game vector with its link vector.
        dot_products = torch.bmm(
            t1.contiguous().view(len(batch), 1, self.embedding_dim),
            t2.contiguous().view(len(batch), self.embedding_dim, 1)
        )
        return dot_products.contiguous().view(len(batch))
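For context, each minibatch is just a Python list of [game_index, link_index, label] triples. I haven’t included build_minibatch, but a minimal stand-in that produces the same shape of data would be something like this (a sketch only; num_games and num_links are assumed to be in scope, and my real sampler draws actual positive pairs rather than random ones):

import random

def build_minibatch(num_positives, num_negatives):
    # Stand-in sampler: each entry is [game_index, link_index, label],
    # with label 1 for positive pairs and -1 for negative samples.
    batch = [[random.randrange(num_games), random.randrange(num_links), 1]
             for _ in range(num_positives)]
    batch += [[random.randrange(num_games), random.randrange(num_links), -1]
              for _ in range(num_negatives)]
    random.shuffle(batch)
    return batch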
Here’s my training loop, skipping the initialization; a rough sketch of that setup follows the loop:
for i in range(num_epochs):
    for j in range(num_steps_per_epoch):
        optimizer.zero_grad()
        minibatch = build_minibatch(num_positives, num_negatives)
        y = model.forward(minibatch)
        target = torch.FloatTensor([v[2] for v in minibatch])
        loss = loss_function(y, target)
        if i == 0 and j == 0:
            print('r: loss = %.3f' % float(loss))
        loss.backward(retain_graph=True)
        optimizer.step()
    print('%s: loss = %.3f' % (i, float(loss)))
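For completeness, the setup I’m skipping looks roughly like the following. The sizes, optimizer, and loss function here are placeholders rather than my exact values; torch.nn.SoftMarginLoss is just one loss that fits the 1 / -1 labels:

import torch

# Placeholder sizes and hyperparameters, not my exact values.
num_games, num_links = 10000, 2000
num_positives, num_negatives = 16, 16
num_epochs, num_steps_per_epoch = 10, 100

model = GameEmbedding(num_games, num_links, embedding_dim=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = torch.nn.SoftMarginLoss()  # expects targets of 1 / -1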
And here’s the detailed error output:
Traceback (most recent call last):
  File "train.py", line 88, in <module>
    loss.backward(retain_graph=True)
  File "/mnt/pool/code/gamesearch/env/lib/python3.7/site-packages/torch/tensor.py", line 150, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/mnt/pool/code/gamesearch/env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: leaf variable has been moved into the graph interior
Sorry I can’t ask a more specific question, but after reading up on other cases of this error and on how leaf variables work, I still have no idea why it’s happening here.