Custom layer giving problems when updating class parameter

Hi everyone. I’m trying to implement a custom layer like this:

import torch.nn as nn
import math
import torch

class LinLayer(nn.Module):
    """ Custom Linear layer but mimics a standard linear layer """
    def __init__(self, size_in, size_out):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        weights = torch.empty(size_out, size_in)  # uninitialized; filled by the init below
        self.weights = nn.Parameter(weights)  # nn.Parameter is a Tensor that's a module parameter.
        bias = torch.empty(size_out)
        self.bias = nn.Parameter(bias)

        # hebbian trace: a plain tensor (not an nn.Parameter), so it is never optimized
        self.hebb = torch.zeros((size_out, size_in), requires_grad=False)

        # initialize weights and biases
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) # weight init
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)  # bias init

    def forward(self, x):
        hebb = self.hebb

        w_times_x = torch.mm(x, self.weights.t() + hebb.t())

        yout = torch.add(w_times_x, self.bias)  # w times x + b

        self.hebb = torch.matmul(yout.t(), x)

        return yout

The problem is that when I try to update self.hebb = torch.matmul(yout.t(), x) at the end, I get "RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)". I have already tried detaching the self.hebb tensor, but that makes my loss NaN. How can I get this code to work?

Detaching sounds like the right approach, since you are currently storing the computation graph in self.hebb. As you are seeing a NaN loss after this fix, you should narrow down where the invalid values come from (e.g. in the forward pass, in the gradients, or somewhere else).
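For reference, a minimal sketch of what that fix looks like in the forward above (everything unchanged except that only the value of the new hebb is stored):

def forward(self, x):
    hebb = self.hebb
    w_times_x = torch.mm(x, self.weights.t() + hebb.t())
    yout = torch.add(w_times_x, self.bias)  # w times x + b
    # store only the value: .detach() returns a tensor cut off from the
    # current graph, so a later backward() never revisits this pass
    self.hebb = torch.matmul(yout.t(), x).detach()
    return yout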

Thanks @ptrblck! Indeed, the problem was coming from somewhere else: self.hebb was simply growing too large and generating NaNs at some point.

I would have a further question, though: in this framework of differentiable plasticity, only self.alphas and self.weights are optimized via backpropagation (so there is no partial derivative of the loss w.r.t. self.hebb). Do I understand correctly that calling self.hebb.detach_() means the partial derivative of the loss w.r.t. self.hebb won't be computed, but that it won't remove the effect of self.hebb's value when applying the chain rule to update self.weights and self.alphas?

This is the forward function (self.alphas is one of the parameters of the model):

def forward(self, x):

    # clamp in-place under no_grad: a bare torch.clamp(self.alphas, ...)
    # returns a new tensor, so its result was being discarded here
    with torch.no_grad():
        self.alphas.clamp_(min=-1.0, max=1.0)

    w_times_x = torch.mm(x, self.weights.t() + torch.mul(self.alphas, self.hebb).t())

    yout = torch.add(w_times_x, self.bias)

    # decayed hebbian update; detach_() drops the autograd history so
    # only the value is carried into the next iteration
    self.hebb = 0.1 * torch.matmul(yout.t(), x)
    self.hebb.detach_()

    return yout
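As a quick, self-contained check of the chain-rule question above (toy shapes, with standalone tensors standing in for self.weights, self.alphas, and self.hebb): a detached hebb receives no gradient of its own, but its value still multiplies alphas in the forward product and therefore still shows up in alphas' gradient.

import torch

w = torch.randn(3, 4, requires_grad=True)       # stands in for self.weights
alphas = torch.randn(3, 4, requires_grad=True)  # stands in for self.alphas
hebb = torch.randn(3, 4)                        # detached value, no graph behind it

x = torch.randn(2, 4)
y = torch.mm(x, (w + alphas * hebb).t())
y.sum().backward()

print(hebb.requires_grad)                          # False: no dLoss/dhebb exists
print(torch.allclose(alphas.grad, hebb * w.grad))  # True: hebb's value scales dLoss/dalphas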

Detaching self.hebb avoids building an ever-growing computation graph, since otherwise you would backpropagate through the forward pass of the current iteration:

hebb = self.hebb
w_times_x = torch.mm(x, self.weights.t() + hebb.t())
yout = torch.add(w_times_x, self.bias)  # w times x + b
self.hebb = torch.matmul(yout.t(), x)

as well as the previous one(s).
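To make this concrete, here is a minimal standalone sketch (a hypothetical two-step loop, not the actual training code) that hits the same RuntimeError once the detach is removed:

import torch

w = torch.randn(3, 4, requires_grad=True)
hebb = torch.zeros(3, 4)

for step in range(2):
    x = torch.randn(2, 4)
    y = torch.mm(x, (w + hebb).t())
    y.sum().backward()  # frees the saved tensors of this iteration's graph
    # .detach() stores only the value; without it, hebb would keep a
    # reference to the (now freed) graph above, and the next backward()
    # would raise "Trying to backward through the graph a second time"
    hebb = torch.matmul(y.t(), x).detach()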