Higher and pytorch lightning MAML Implementation

I am having some trouble implementing MAML using pytorch lightning and higher.

To make things simple, I am starting with only 1 task, using the CIFAR-10 dataset, and a simple Neural Network.

The issue is that higher is giving me this error (In the line I use higher.innerloop_ctx), which I believe is related to the way I defined my classes:

TypeError: Cannot create a consistent method resolution
order (MRO) for bases Module, ABC

I was wondering if someone could clarify what my mistake is. I am finding it hard to figure this problem out since I haven’t found helpful documentation and examples involving both higher and lightning.

Sorry for the verbose, but for the sake of reproducibility, here is my full code:

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader, random_split
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.nn as nn
import higher
import torch
import wandb


class MyClassifier(nn.Module):
    def __init__(self):
        super(MyClassifier, self).__init__()
        self.accuracy = pl.metrics.Accuracy()
        self.input = nn.Linear(in_features=28*28, out_features=15*15)
        self.drop1 = nn.Dropout(0.05)
        self.fc1 = nn.Linear(in_features=15*15, out_features=10*10)
        self.drop2 = nn.Dropout(0.05)
        self.out = nn.Linear(in_features=10*10, out_features=10)
        
    def forward(self, x):
        self.batch_size = x.shape[0]
        x = self.input(x.reshape(self.batch_size, -1))
        x = self.drop1(F.relu(self.fc1(x)))
        x = self.out(x)
        return x

    
class MetaLearner(pl.LightningModule):
    def __init__(self, model, model_lr, learner_lr):
        super().__init__()
        self.model = model
        self.model_lr = model_lr
        self.learner_lr = learner_lr
        self.accuracy = pl.metrics.Accuracy()
        
    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        meta_optim = torch.optim.Adam(self.parameters(), lr=self.learner_lr)
        model_optim = torch.optim.SGD(self.parameters(), lr=self.model_lr)
        return [meta_optim, model_optim]
    
    def task_loss(self, y_hat, y):
        return (1 / y.shape[0]) * F.cross_entropy(y_hat, y)
        
    def training_step(self, batch, batch_idx, optimizer_idx):
        # Optimizers
        meta_optim, model_optim = self.optimizers()
        
        # Train and test data
        # TODO: actually split data between train and test
        X_train, Y_train = batch
        X_test, Y_test = batch
        
        meta_loss = torch.tensor(0.0, device=self.device)
        with higher.innerloop_ctx(self.model, model_optim, copy_initial_weights=False) as (fmodel, diff_opt):
            # Update parameters for task on train inputs
            y_hat = self(x_train)
            model_loss = self.task_loss(y_hat, y_train)
            diff_opt.step(model_loss)

            # Update parameters for task on test inputs
            y_hat = self(x_test)
            task_loss = self.task_loss(y_hat, y_test)
            meta_loss += task_loss
            self.log(f"meta_loss_task0", task_loss)
            self.log(f"meta_acc_task0", self.accuracy(y_hat, y_test))
            
        # Update meta-parameters
        self.manual_backward(meta_loss, meta_optim)
        
        return meta_loss
    
    def validation_step(self, batch, _):
        x, y = batch
        y_hat = self(x)
        loss = self.task_loss(y_hat, y)
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_acc", self.accuracy(y_hat, y), on_epoch=True)
        return loss
        
        
class MNISTDataModule(pl.LightningDataModule):
    
    def __init__(self, data_dir="./Data", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        
    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
        
    def setup(self, stage=None):
        if stage == "fit" or stage == None:
            mnist = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist, [50000, 10000])
        if stage == "test" or stage == None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            
    def train_dataloader(self):
        return DataLoader(self.mnist_train, self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, self.batch_size)
    
    
# Weights and Biases Logger
wandb_logger = WandbLogger(project="Meta-Learning")

# Early Stopping Callback
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

# Create model
model = MyClassifier()

# Create DataModule
mnist = MNISTDataModule()
mnist.prepare_data()
mnist.setup()

# Create Learner
learner = MetaLearner(model, 1e-3, 1e-3)

# Create pytorch trainer
trainer = pl.Trainer(logger=wandb_logger, deterministic=True, callbacks=[early_stop_callback], 
                     profiler=True, automatic_optimization=False)

# Train the model
trainer.fit(learner, datamodule=mnist)

Unfortunately, I don’t know the solution to your issue, but the error might be generally raised, if you try to derive a class from multiple parent classes and e.g. use the wrong order.
I cannot see any obvious code snippets in your post, which would fit this, but this information might be useful nevertheless for debugging.

1 Like