I have a piece of code that performs a transformation on the input tokens of a network; it looks like this:
class MyModule(nn.Module):
    def __init__(self, val, max=15000, step=50, device='cpu'):
        super(MyModule, self).__init__()
        self.device = device
        self.tt = torch.arange(0, max, step=step, device=device, requires_grad=False).unsqueeze(0).long()
        self.val = val

    def forward(self, x, st):
        st = st.unsqueeze(-1).repeat(1, 1, self.tt.size(1)).long()
        mask = (st <= self.tt).long()
        out = x.clone().detach()  # this doesn't work
        out = out.unsqueeze(-1).repeat(1, 1, self.tt.size(1))
        out = (out * mask).sum(1)
        out[out == 0] = self.val
        return out, mask.long()
The problem is that when backward() is called, I get the following error:
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: leaf variable has been moved into the graph interior
From what I have read, this is related to in-place modification of variables, but I can't seem to find a solution.
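For reference, here is a sketch of what I think the fix might look like, assuming the in-place indexing assignment (out[out == 0] = self.val) is what triggers the error. It replaces that assignment with an out-of-place torch.where and drops the clone().detach() so gradients can still flow from x. The class name MyModuleFixed is just for illustration, and I'm not certain this is the right approach:

import torch
import torch.nn as nn

class MyModuleFixed(nn.Module):  # hypothetical name, same logic as above
    def __init__(self, val, max=15000, step=50, device='cpu'):
        super().__init__()
        self.device = device
        self.tt = torch.arange(0, max, step=step, device=device).unsqueeze(0).long()
        self.val = val

    def forward(self, x, st):
        st = st.unsqueeze(-1).repeat(1, 1, self.tt.size(1)).long()
        mask = (st <= self.tt).long()
        # keep x attached to the graph: no clone().detach()
        out = x.unsqueeze(-1).repeat(1, 1, self.tt.size(1))
        out = (out * mask).sum(1)
        # out-of-place replacement of zero entries instead of out[out == 0] = self.val
        out = torch.where(out == 0, torch.full_like(out, self.val), out)
        return out, mask

If I understand correctly, torch.where builds a new tensor rather than mutating an existing one, so autograd should no longer see a leaf being modified in place, but I would appreciate confirmation that this preserves the gradients I expect.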