Implicit Gradients for Meta Parameters returns None

I am attempting to implement the implicit gradients algorithm [1, 2, 3] to optimize some meta-parameters (in my case the parameters of a loss function). However, the (meta-)gradients produced are always None. Can I have some help identifying what the problem is, and how I can resolve this issue?

Below I have attached some simplified code that reproduces the error.

from sklearn.datasets import make_regression
import torch

# Creating a meta-network for representing the loss function.
class MetaNetwork(torch.nn.Module):

    def __init__(self):
        super(MetaNetwork, self).__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(2, 10),
            torch.nn.Linear(10, 1),

    def forward(self, y_pred, y_target):
        return self.model(, y_target), dim=1)).mean()

# Creating a base-network for learning the model of the data.
class BaseNetwork(torch.nn.Module):

    def __init__(self):
        super(BaseNetwork, self).__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.model(x)

# Generating some synthetic training and validation data.
X_train, y_train = make_regression(n_samples=100, n_features=1, n_informative=1, noise=0.1, random_state=1)
X_valid, y_valid = make_regression(n_samples=100, n_features=1, n_informative=1, noise=0.1, random_state=2)

# Converting data into the correct format.
X_train, y_train = torch.tensor(X_train).float(), torch.unsqueeze(torch.tensor(y_train).float(), 1)
X_valid, y_valid = torch.tensor(X_valid).float(), torch.unsqueeze(torch.tensor(y_valid).float(), 1)

# Creating our base and meta models, as well as the base optimizer.
meta_network, base_network = MetaNetwork(), BaseNetwork()
base_optimizer = torch.optim.SGD(base_network.parameters(), lr=0.01)

# Training the model using the meta-network as the loss function.
for i in range(10):
    yp = base_network(X_train)
    base_loss = meta_network(yp, y_train)

meta_loss_fn = torch.nn.MSELoss()

# Computing the training and validation (meta) loss.
train_loss = meta_loss_fn(base_network(X_train), y_train)
validation_loss = meta_loss_fn(base_network(X_valid), y_valid)

# Gradient of the validation loss with respect to the base model weights.
dloss_val_dparams = torch.autograd.grad(validation_loss, base_network.parameters(),
                                        retain_graph=True, allow_unused=True)

# Gradient of the training loss with respect to the base model weights.
dloss_train_dparams = torch.autograd.grad(train_loss, base_network.parameters(),
                                          create_graph=True, allow_unused=True)

p = v = dloss_val_dparams

for _ in range(10):
    grad = torch.autograd.grad(dloss_train_dparams, base_network.parameters(),
                               grad_outputs=v, retain_graph=True, allow_unused=True)

    grad = [g * 0.01 for g in grad]

    v = [curr_v - curr_g for (curr_v, curr_g) in zip(v, grad)]
    p = [curr_p + curr_v for (curr_p, curr_v) in zip(p, v)]

v2 = list(0.01 * pp for pp in p)
v3 = torch.autograd.grad(dloss_train_dparams, meta_network.parameters(), grad_outputs=v2, allow_unused=True)

print("Meta Gradient", v3)