Learnable LR, not getting gradients

Hi everyone,

I am trying to implement Meta-SGD, where the learning rates themselves are learnable, but I am having trouble getting gradients for them.

The code:

        for j in range(iterations):
            self.optim.zero_grad()
            inner_output = self.model(inner_x)
            inner_loss = self.criterion(inner_output, inner_y)
            # create_graph=True keeps the inner gradients in the autograd graph
            grads = torch.autograd.grad(inner_loss, self.model.parameters(), create_graph=True)
            for i, param in enumerate(self.model.parameters()):
                # scale each gradient by its learnable learning rate
                param.grad = self.lrs[i] * grads[i]
            self.optim.step()

        outer_output = self.model(outer_x)
        outer_loss = self.criterion(outer_output, outer_y)
        
        # These lines should create the grads for the learning rates and update them accordingly.
        outer_loss.backward()
        # observed: self.lrs[i].grad is still None at this point
        other_optimizer.step()

What happens is that lrs.grad stays None, which means that param.grad = self.lrs[i] * grads[i] probably doesn't add anything to the computation graph used for calculating the gradients.
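
I suspect the reason is self.optim.step() itself. As far as I understand, for plain SGD (ignoring momentum and weight decay) the step is roughly equivalent to the following, where optimizer_lr is just a placeholder for the lr the optimizer was constructed with:

        with torch.no_grad():
            for param in self.model.parameters():
                # in-place, untracked update: whatever was written into param.grad
                # (including the factor self.lrs[i]) is not part of any graph afterwards
                param -= optimizer_lr * param.grad

So the parameters are overwritten in place with no_grad active, and the outer loss only ever sees ordinary leaf parameters with no path back to self.lrs.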

And if I change the code to:

        for j in range(iterations):
            self.optim.zero_grad()
            inner_output = self.model(inner_x)
            inner_loss = self.criterion(inner_output, inner_y)
            grads = torch.autograd.grad(inner_loss, self.model.parameters(), create_graph=True)
            for i, param in enumerate(self.model.parameters()):
                # in-place update of a leaf parameter that requires grad
                param -= self.lrs[i] * grads[i]

        outer_output = self.model(outer_x)
        outer_loss = self.criterion(outer_output, outer_y)
        
        # These lines should create the grads for the learning rates and update them accordingly.
        outer_loss.backward()
        other_optimizer.step()

Then param -= self.lrs[i] * grads[i] gives a RuntimeError because of the in-place operation, which is not allowed on leaf variables that require grad.
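
As far as I can tell this is the general restriction on in-place ops for leaf tensors that require grad; this minimal example (not my actual code) fails the same way:

        import torch

        w = torch.ones(3, requires_grad=True)   # a leaf tensor, like a model parameter
        w -= 0.1 * torch.ones(3)                # RuntimeError: a leaf Variable that requires
                                                # grad is being used in an in-place operation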

<edit>
I also tried param = param - self.lrs[i] * grads[i]. With that, no learning happens either, probably because the left-hand param is just a new variable and the model's parameters are never actually replaced. And self.lrs.grad = None as well.
</edit>
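
I think what happens in that version is that the assignment only rebinds the Python loop variable and never touches the parameter stored inside the module, e.g. (hypothetical toy example):

        import torch
        import torch.nn as nn

        model = nn.Linear(2, 1)
        for p in model.parameters():
            p = p - 0.1   # binds a new tensor to the local name p only
        # model.parameters() still returns the original, unchanged leaf tensors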

Does anyone know a solution, i.e. a way to get gradients for the learning rates?
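
The direction I am considering is to keep differentiable "fast weights" outside the module and run the forward pass functionally, roughly like the sketch below (assuming torch.func.functional_call from newer PyTorch versions; the name fast_weights is just mine, and I am not sure this is the right approach):

        import torch
        from torch.func import functional_call

        # start from the module's own parameters as "fast weights"
        fast_weights = dict(self.model.named_parameters())

        for j in range(iterations):
            inner_output = functional_call(self.model, fast_weights, (inner_x,))
            inner_loss = self.criterion(inner_output, inner_y)
            grads = torch.autograd.grad(inner_loss, tuple(fast_weights.values()),
                                        create_graph=True)
            # out-of-place update, so the graph through self.lrs is kept
            fast_weights = {name: w - self.lrs[i] * grads[i]
                            for i, (name, w) in enumerate(fast_weights.items())}

        outer_output = functional_call(self.model, fast_weights, (outer_x,))
        outer_loss = self.criterion(outer_output, outer_y)
        outer_loss.backward()      # should populate self.lrs[i].grad
        other_optimizer.step()

But I am not sure whether this is correct or whether there is a more standard way to do it.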