Hi all,
I have a module that applies a transformation to a network's input tokens. It looks like this:
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, val, max=15000, step=50, device='cpu'):
        super(MyModule, self).__init__()
        self.device = device
        self.tt = torch.arange(0, max, step=step, device=device, requires_grad=False).unsqueeze(0).long()
        self.val = val

    def forward(self, x, st):
        st = st.unsqueeze(-1).repeat(1, 1, self.tt.size(1)).long()
        mask = (st <= self.tt).long()
        out = x.clone().detach()  # this doesn't work
        out = out.unsqueeze(-1).repeat(1, 1, self.tt.size(1))
        out = (out * mask).sum(1)
        out[out == 0] = self.val
        return out, mask.long()
The problem is that when backward() is called, I get the following error:
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: leaf variable has been moved into the graph interior
From what I have read, this error is related to in-place modification of tensors, but I can't seem to find a solution.
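One thing I considered (not sure it's the right fix, just a sketch of the idea) is replacing the in-place assignment out[out == 0] = self.val with an out-of-place torch.where, so autograd never sees an in-place write:

```python
import torch

val = -1.0  # placeholder for self.val, just for this toy example
out = torch.tensor([0.0, 2.0, 0.0, 3.0], requires_grad=True)

# Out-of-place replacement of zeros, instead of `out[out == 0] = val`:
result = torch.where(out == 0, torch.full_like(out, val), out)

result.sum().backward()
print(result)    # tensor([-1., 2., -1., 3.], grad_fn=...)
print(out.grad)  # tensor([0., 1., 0., 1.]) - no gradient through replaced entries
```

In this toy case it runs and gradients flow through the untouched entries, but I don't know whether it addresses the error in my actual module, or whether the clone().detach() is the real culprit.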
Any suggestions?
Thanks