I'm encountering a problem with the new torch.einsum
in PyTorch 0.4.
What I want to do is simply use torch.einsum
to implement the dot-product attention commonly used in seq2seq models. The following is a snippet I created:
import torch

bsz, clen, qlen = 32, 20, 20
d_hid = 64

# context/query use the usual seq2seq layout: (seq_len, batch, hidden)
context = torch.rand(clen, bsz, d_hid, requires_grad=True)
query = torch.rand(qlen, bsz, d_hid, requires_grad=True)
key = val = context

# attention scores: (clen, qlen, bsz)
att_score = torch.einsum('cbd,qbd->cqb', (key, query))
# normalize over the context dimension
att_prob = torch.nn.functional.softmax(att_score, dim=0)
# weighted sum over context: (qlen, bsz, d_hid)
att_vec = torch.einsum('cbd,cqb->qbd', (val, att_prob))
att_vec.mean().backward()
However, the backward
call raises the following runtime error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.
I’m not quite sure what’s going on here. Any idea?
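In case it helps narrow things down: I believe the version below computes the same attention with torch.bmm instead of einsum (the permutes just rearrange the operands into a batch-first layout for bmm and back). This is only a sketch based on my understanding that the two formulations contract the same dimensions; if it backpropagates fine, that would suggest the problem is specific to einsum's backward.

import torch
import torch.nn.functional as F

bsz, clen, qlen, d_hid = 32, 20, 20, 64
context = torch.rand(clen, bsz, d_hid, requires_grad=True)
query = torch.rand(qlen, bsz, d_hid, requires_grad=True)
key = val = context

# scores, same contraction as einsum('cbd,qbd->cqb', ...)
att_score = torch.bmm(key.permute(1, 0, 2),    # (bsz, clen, d_hid)
                      query.permute(1, 2, 0))  # (bsz, d_hid, qlen)
att_score = att_score.permute(1, 2, 0)         # (bsz, clen, qlen) -> (clen, qlen, bsz)
att_prob = F.softmax(att_score, dim=0)         # normalize over context positions

# weighted sum, same as einsum('cbd,cqb->qbd', ...)
att_vec = torch.bmm(att_prob.permute(2, 1, 0), # (bsz, qlen, clen)
                    val.permute(1, 0, 2))      # (bsz, clen, d_hid)
att_vec = att_vec.permute(1, 0, 2)             # back to (qlen, bsz, d_hid)
att_vec.mean().backward()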