Einsum problem in PyTorch 0.4

I'm running into a problem with the new torch.einsum in PyTorch 0.4.

All I want to do is use torch.einsum to implement the dot-product attention commonly used in seq2seq models. The following is a snippet that reproduces the issue.

import torch

bsz, clen, qlen = 32, 20, 20
d_hid = 64

context = torch.rand(clen, bsz, d_hid, requires_grad=True)  # (clen, bsz, d_hid)
query = torch.rand(qlen, bsz, d_hid, requires_grad=True)     # (qlen, bsz, d_hid)
key = val = context

att_score = torch.einsum('cbd,qbd->cqb', (key, query))        # (clen, qlen, bsz)
att_prob = torch.nn.functional.softmax(att_score, dim=0)      # normalize over the context dim
att_vec = torch.einsum('cbd,cqb->qbd', (val, att_prob))       # (qlen, bsz, d_hid)

att_vec.mean().backward()

However, the backward call raises the following runtime error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.

I’m not quite sure what’s going on here. Any idea?


That error is usually raised when a tensor required for gradient computation is modified in place. However, I don't see any in-place modifications here, unless einsum does some internally…
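
As an aside, more recent PyTorch versions (not 0.4 itself, as far as I know) ship an anomaly-detection mode that makes the backward error print the traceback of the forward op whose gradient computation failed, which helps pin down where an in-place modification (if any) happened. A minimal sketch, reusing the shapes from your snippet:

import torch

bsz, clen, qlen, d_hid = 32, 20, 20, 64
key = val = torch.rand(clen, bsz, d_hid, requires_grad=True)
query = torch.rand(qlen, bsz, d_hid, requires_grad=True)

# Anomaly mode records forward tracebacks so the backward error points at
# the exact forward op whose gradient computation failed.
with torch.autograd.detect_anomaly():
    att_score = torch.einsum('cbd,qbd->cqb', (key, query))
    att_prob = torch.nn.functional.softmax(att_score, dim=0)
    att_vec = torch.einsum('cbd,cqb->qbd', (val, att_prob))
    att_vec.mean().backward()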

Have you solved it?
If I replace the second torch.einsum with att_vec = (val.unsqueeze(1) * att_prob.unsqueeze(-1)).sum(dim=0), I get the same error.
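
In context, what I tried looks like this (same setup as the original snippet; only the second einsum is replaced by explicit broadcasting, with the resulting shapes noted in comments):

import torch

bsz, clen, qlen, d_hid = 32, 20, 20, 64
key = val = torch.rand(clen, bsz, d_hid, requires_grad=True)
query = torch.rand(qlen, bsz, d_hid, requires_grad=True)

att_score = torch.einsum('cbd,qbd->cqb', (key, query))            # (clen, qlen, bsz)
att_prob = torch.nn.functional.softmax(att_score, dim=0)
# Replacement for torch.einsum('cbd,cqb->qbd', (val, att_prob)):
# broadcast to (clen, qlen, bsz, d_hid), then sum out the context dim.
att_vec = (val.unsqueeze(1) * att_prob.unsqueeze(-1)).sum(dim=0)  # (qlen, bsz, d_hid)

att_vec.mean().backward()  # still raises the same RuntimeError on 0.4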

What about cloning the input variables? Is that a correct way to work around this problem?

UPDATE: the latest version of PyTorch has fixed this bug. For 0.4, I have seen people work around it by cloning the variables.
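
For reference, that workaround looks roughly like the line below, dropped into the original snippet in place of the second einsum. I haven't verified whether cloning both operands is necessary or cloning just one of them is enough.

# Replaces the second einsum in the snippet above; cloning the operands
# reportedly avoids the RuntimeError on 0.4.
att_vec = torch.einsum('cbd,cqb->qbd', (val.clone(), att_prob.clone()))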