Computing weights from a set of independent parameters

I have a neural network with a single linear layer that has aN^2 weights, of which only aN are independent. In my network class I therefore define the independent weights as nn.Parameters, and from these I compute the remaining weights (I have written a custom function that does this). In one training iteration only the independent parameters get updated by calling loss.backward() followed by optim.step(), which is what I want. However, I also want all my weights to be updated automatically whenever these independent parameters change. I tried something like this from within my network class:

@property
def weights(self):
    return self.__weights

@weights.setter
def weights(self, new_value):
    self.__weights = new_value
    self.set_forward_weights()

Here set_forward_weights() updates all weights using the independent parameters.
This does not work. Could someone suggest an elegant way of solving this? Or do I need to call set_forward_weights() every time I call optim.step()? (Everything works fine when I do that, but I think a better solution should be possible.)

Every time you call optim.step() you update your Parameters. If your weights depend on the parameters, I'd expect you do need to call it every time?

Yes, sorry for the confusion. I do need to update them every time, of course, but I want this to happen automatically. Whenever these independent parameters change, for whatever reason, I want all my weights to be updated without manually calling set_forward_weights(), i.e. the moment the nn.Parameters are changed, set_forward_weights() should be called. The example snippet I gave should do this, but it does not work for nn.Parameters (or I can't get it to work for them). Does this clarify my problem?

You can try using backward hooks. If I understand correctly, this is what you’re attempting?

import torch
import torch.nn as nn

class MWE(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones(2, 2)  # plain tensor, not an nn.Parameter
        self.fc = nn.Linear(2, 2)

    def upd_wt(self, x):
        self.weight += x

    def forward(self, x):
        return self.fc(x)

# Full backward hook: runs after the gradients for this module have been
# computed, so the dependent tensor can be refreshed here.
def grad_upd(module, grad_inputs, grad_outputs):
    module.upd_wt(grad_outputs[0])

net = MWE()
h = net.register_full_backward_hook(grad_upd)
net(torch.rand(2, 2)).pow(2).sum().backward()

print(net.weight)  # should be the updated values

I will post part of my script tomorrow to clarify. In general, this won’t work for me, but maybe with a few adaptations it will. Thanks, and I will get back to you tomorrow.

This might be due to my lack of experience with PyTorch, but I can't get it to work with a hook.
Here is a snippet of my code:

class RBM_SYMMETRIC(nn.Module):
    def __init__(self, L, a):
        super(RBM_SYMMETRIC, self).__init__()
        self.N = L * L
        self.weights = nn.Parameter(torch.FloatTensor(a * self.N).uniform_(-0.01, 0.01))
        self.biases = nn.Parameter(torch.FloatTensor(a * self.N).uniform_(-0.01, 0.01))
        self.set_forward_weights()

    def set_forward_weights(self):
        # A lot of stuff I can't share happens here, but it creates a torch tensor
        # of shape (N, aN) from self.weights, which holds the aN independent parameters.
        self.forward_weights = ...  # torch tensor of shape (N, aN) built from self.weights

    def forward(self, state):
        ws = state @ self.forward_weights + self.biases
        y = ...  # some activation function applied to ws
        return y

So, whenever I update my network during training, only self.weights and self.biases get updated, which is fine since they are the only independent parameters. However, when this update happens I want set_forward_weights() to be called automatically. Currently I call it manually after every update, i.e. every time I call optim.step() I call net.set_forward_weights() right after (this works fine); a rough sketch of that loop is below.
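For reference, one training step currently looks roughly like this (optim, data_loader, and loss_fn are just placeholders, not my actual names):

for state, target in data_loader:        # placeholder data loader
    optim.zero_grad()
    loss = loss_fn(net(state), target)   # placeholder loss function
    loss.backward()
    optim.step()                         # updates only self.weights and self.biases
    net.set_forward_weights()            # manually rebuild forward_weights afterwards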

Can this work with a hook and if so how?
Or is there another way to do this?

One important thing to note is that I have aN independent weights, not aN^2, even though I use aN^2 weights in the forward pass. I don't want there to be aN^2 nn.Parameters / independent weights, because I use a special update rule that scales exponentially with the number of independent variables.

I am afraid this is as clear as I can make it.

I think one solution might be to register only weights and biases with the optimizer, and not register the forward_weights. This way, when you call optim.step(), it will only update your “independent” variables, and your “dependent” variables will be updated by your function (rough sketch below).

Note that if you want to learn the weights, gradients still need to flow through your forward_weights (because of autograd’s chain rule).
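Something along these lines (the optimizer choice, learning rate, and constructor arguments here are just placeholders):

net = RBM_SYMMETRIC(L=4, a=2)   # placeholder arguments

# Pass only the independent parameters to the optimizer; forward_weights is a
# plain tensor, so it gets rebuilt from them by set_forward_weights() instead.
optim = torch.optim.SGD([net.weights, net.biases], lr=0.01)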

@Dominique_Kosters I wonder if you are aware of parametrization? You might be able to use that here - Parametrizations Tutorial — PyTorch Tutorials 1.10.1+cu102 documentation
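The basic pattern from that tutorial looks roughly like this (the Symmetric example below is the tutorial’s symmetric-matrix parametrization, not your specific symmetry rule):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, W):
        # Rebuild the full weight matrix from its upper triangle, so only the
        # upper-triangular entries act as free parameters.
        return W.triu() + W.triu(1).transpose(-1, -2)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())
print(layer.weight)  # recomputed from the underlying parameter on each access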


Sorry I didn’t answer sooner; I was unable to work on the project. Parametrization seems to be exactly what I want. With parametrization everything works and the code looks cleaner, but it is a factor of 10 slower than what I had originally. (Originally I construct a set of independent parameters, which has a smaller dimension than the full set of weights, and build my weights from it.) If I understand correctly, parametrization first constructs the normal linear layer with its weights and biases and then adjusts the weights according to the imposed symmetry rule. I think the drop in performance arises because the original weights and biases are still being tracked. Part of the reason for using symmetries is to reduce the computation time, so although parametrization is exactly what I asked for, it is sadly not the solution. Thanks for your help; if you think of another solution, please let me know.

I fixed the issue: every time the linear layer is called it recomputes the parametrization, but it should only do so the first time it is used and after it is updated. This can be achieved with “with parametrize.cached():”. This solved the long computation times. I will mark your answer as the solution. Thanks again.
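In case it helps someone else, the part that evaluates the network many times between updates now looks roughly like this (the loop and variable names are placeholders):

import torch.nn.utils.parametrize as parametrize

with parametrize.cached():
    # Inside this block the parametrized weight is computed once and reused,
    # instead of being rebuilt on every forward call.
    for state in states:      # placeholder evaluation loop
        out = net(state)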