Hi there, I’m new to PyTorch and I’ve been trying to fix a runtime error without success. I have defined a custom layer that has the following forward pass:

```
def forward(self, x):
    w_times_x = torch.mm(x, self.weights.t() + torch.mul(self.alphas, self.h_ij).t())
    yout = torch.add(w_times_x, self.bias)
    # running average of outer products of input and output activations
    self.h_ij = 0.99 * self.h_ij + 0.01 * torch.ger(x[0], yout[0]).t()
    return yout
```

Only **self.weights** and **self.alphas** are parameters of the model. The tensor **self.h_ij** simply accumulates the product of unit **i** and **j** activations over forward passes (for that I use a batch size of 1, so the training samples are presented sequentially and the value of **self.h_ij** updated at the end of one forward pass is used during the forward pass of the next training sample).
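In isolation, the accumulation behaves like a running average of outer products. A toy version with made-up sizes (2 inputs, 3 outputs), just to show the shapes:

```python
import torch

# Toy version of the h_ij update with batch size 1: h converges toward the
# outer product of the (fixed) input and output activations.
x = torch.tensor([1.0, 2.0])        # input activations (size_in = 2)
y = torch.tensor([3.0, 1.0, 0.0])   # output activations (size_out = 3)
h = torch.zeros(3, 2)               # (size_out, size_in), like self.h_ij

for _ in range(5):
    # same update rule as in forward(): decay old value, add new outer product
    h = 0.99 * h + 0.01 * torch.ger(x, y).t()
```

After n steps, h equals `ger(x, y).t() * (1 - 0.99 ** n)` for constant activations.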

The training loop is the following:

```
for epoch in range(num_epochs):
    for batch_inputs, batch_outputs in dataloader:
        optimizer.zero_grad()
        predictions = model(batch_inputs)
        loss = loss_function(predictions, batch_outputs)
        loss.backward()
        optimizer.step()
```

The problem I’m having is that after the first **.backward()** call succeeds, the next one raises **RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)**. At the end of this error message it is suggested that I use **.backward(retain_graph=True)**. If I do that, I get a different error: **RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation**.
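For reference, I believe the first error can be reproduced with a stripped-down version of the pattern (a state tensor carried across iterations; names and shapes are made up):

```python
import torch

# Minimal reproducer (my understanding): a state tensor that remains part of
# the previous iteration's graph makes the second backward() re-traverse a
# graph whose saved tensors were already freed.
w = torch.ones(1, requires_grad=True)
state = torch.zeros(1)  # carried across iterations, like self.h_ij

err = None
for step in range(2):
    out = w * state + w       # after step 0, state is graph-connected
    state = state + out       # new state keeps a reference to this step's graph
    try:
        out.sum().backward()  # frees the graph; step 1 walks back into step 0
    except RuntimeError as e:
        err = e
        break
```

Here the second `backward()` has to backprop through `state`, which leads back into the first iteration’s (already freed) graph.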

What am I doing wrong? I know that if I call **self.h_ij.detach_()** at the end of the forward method I don’t get the first error, but this will (if I understand correctly) remove **self.h_ij** from the computational graph, so it won’t be taken into consideration when computing the gradients with respect to **self.alphas**.
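For concreteness, here is a toy sketch of the detach variant (weights dropped and shapes simplified, so it isn’t my actual layer): detaching seems to cut the gradient only through the *history* of **h_ij**, while **self.alphas** still receives a gradient from the current step’s `mul(alphas, h_ij)`.

```python
import torch

alphas = torch.ones(2, 3, requires_grad=True)
h_ij = torch.zeros(2, 3)       # state, not a parameter

x = torch.randn(1, 3)
# h_ij enters the product as a constant (+1.0 just to make the grad nonzero)
out = torch.mm(x, torch.mul(alphas, h_ij + 1.0).t())
# detach the *new* state so the next step does not reach into this graph
h_ij = (0.99 * h_ij + 0.01 * torch.ger(x[0], out[0]).t()).detach()
out.sum().backward()
```

In this toy case `alphas.grad` is populated by the current step; only the contribution of earlier steps through the accumulated `h_ij` is truncated.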

Another way I can avoid the error is by simply “rebuilding” the tensor (as in the code below), but I think this must be breaking the computational graph somehow.

```
def rebuild_h_ij(self):
    _aux = self.h_ij.tolist().copy()
    self.h_ij = torch.zeros((self.size_out, self.size_in), requires_grad=False)
    self.h_ij = torch.tensor(_aux, requires_grad=False, dtype=torch.float32)
```
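A quick check with toy tensors suggests the round trip through `tolist()` severs the graph exactly like a detach would, just more slowly:

```python
import torch

# Checking my suspicion: rebuilding a tensor via tolist() produces a fresh
# leaf tensor with no grad_fn, i.e. the graph connection is gone.
w = torch.ones(2, 2, requires_grad=True)
h = w * 3.0                                               # part of the graph
rebuilt = torch.tensor(h.tolist(), dtype=torch.float32)   # graph is gone
```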

Some help with this would be much appreciated. Thanks in advance.