Meta Learning - keep parameter dependence after gradient step

Hi everyone,

I want to implement a meta-learning algorithm and am struggling a bit. I have two sets of parameters, phi and theta, which have the same shapes. phi are the ordinary parameters that are supposed to minimize some loss, and theta are meta-parameters that regularize phi. The inner update is

phi ← phi - ∇_phi( loss(phi) + (phi - theta)^2 )

i.e. I minimize the normal loss while also keeping phi close to theta.

Now the new phi has a dependence on theta.

I want to repeat that step a couple of times and then take the gradient of loss(phi) with respect to theta; the dependence exists only through the regularization term.
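
To make that dependence explicit, writing the inner update with plain SGD and step size α for clarity (the code below uses Adam, but the idea is the same), two inner steps unroll as

phi_1 = phi_0 - α * ∇_phi( loss(phi_0) + (phi_0 - theta)^2 )
phi_2 = phi_1 - α * ∇_phi( loss(phi_1) + (phi_1 - theta)^2 )

so loss(phi_2) is a function of theta through both updates, and that whole chain has to stay on the autograd tape for the gradient with respect to theta to be computable.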

However, this dependence is not retained in the PyTorch code I have below. How can I retain it so that I can compute the gradient with respect to theta?

import torch
from torch import nn
import torch.optim as optim
import higher


class Model(nn.Module):

    def __init__(self, n_in, n_out):
        super().__init__()
        self.x = nn.Linear(n_in, n_out)

    def forward(self, input):
        return self.x(input)


x = Model(2, 1)  # phi: the parameters adapted in the inner loop
y = Model(2, 1)  # theta: the meta-parameters used in the regularizer

in_ = torch.randn(2,)

inner_opt = optim.Adam(x.parameters(), lr=0.05)
optimizer_y = optim.Adam(y.parameters(), lr=0.05)

with higher.innerloop_ctx(x, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
    # Take a few differentiable gradient steps on the loss plus the
    # (phi - theta)^2 regularizer. higher keeps the updated fast
    # weights inside fnet so the steps stay on the autograd tape.
    for _ in range(2):
        spt_loss = fnet(in_)  # raw network output as a stand-in loss
        reg_loss = 0
        # NB: x.parameters() are the original weights of x, not the
        # fast weights that diffopt updates inside fnet, so after the
        # first step this regularizer no longer involves the current phi.
        for p, q in zip(x.parameters(), y.parameters()):
            reg_loss += ((p - q) ** 2).sum()
        spt_loss += reg_loss
        diffopt.step(spt_loss)  # differentiable inner update of phi

    loss = fnet(in_)  # meta-loss, evaluated with the adapted phi
    loss.backward()   # should backprop through the inner steps into y

# output here is None for every parameter of y,
# i.e. no gradient reached theta
for q in y.parameters():
    print(q.grad)

optimizer_y.step()  # a no-op as long as y.grad is None
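
From reading the higher docs, I suspect the regularizer has to be built from fnet.parameters() -- the differentiable fast weights that diffopt updates -- instead of x.parameters(), so that every inner step re-ties the current phi to theta. Here is a sketch of the inner loop I would try, replacing the with-block above (untested; everything else stays as defined above):

with higher.innerloop_ctx(x, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
    for _ in range(2):
        spt_loss = fnet(in_)
        reg_loss = 0
        # use the current fast weights, so each step depends on theta
        for p, q in zip(fnet.parameters(), y.parameters()):
            reg_loss += ((p - q) ** 2).sum()
        diffopt.step(spt_loss + reg_loss)

    loss = fnet(in_)
    loss.backward()  # should now reach y through the unrolled updates

for q in y.parameters():
    print(q.grad)  # expecting non-None gradients here

Is that the right way to do it, or is there a cleaner approach?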