I am having some trouble implementing MAML using pytorch lightning and higher.
To make things simple, I am starting with a single task, using the MNIST dataset (as in the code below) and a simple feed-forward network.
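(For context: MAML minimizes the post-adaptation loss Σ_τ L_τ(θ − α∇_θ L_τ(θ)) with respect to the meta-parameters θ, so the inner gradient step has to stay inside the autograd graph; that is what higher is supposed to provide here.)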
The issue is that higher raises the following error on the line where I call higher.innerloop_ctx, which I believe is related to the way I defined my classes:
TypeError: Cannot create a consistent method resolution
order (MRO) for bases Module, ABC
Could someone clarify what my mistake is? I am finding this hard to debug, since I haven't found documentation or examples that combine higher with Lightning.
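For reference, this is the pattern I am trying to adapt: a minimal sketch of higher's inner-loop usage on a plain nn.Module, with made-up dummy tensors, based on the usage shown in the higher README:

import higher
import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Linear(4, 2)                                 # stand-in for a real model
inner_opt = torch.optim.SGD(net.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))   # dummy support/query batch

with higher.innerloop_ctx(net, inner_opt) as (fmodel, diffopt):
    # Inner step: updates a functional copy of the model, keeping it differentiable
    diffopt.step(F.cross_entropy(fmodel(x), y))
    # Outer step: backprop the post-update loss through the inner update
    F.cross_entropy(fmodel(x), y).backward()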
Sorry for the verbosity, but for the sake of reproducibility, here is my full code:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader, random_split
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.nn as nn
import higher
import torch
import wandb
class MyClassifier(nn.Module):

    def __init__(self):
        super().__init__()
        self.accuracy = pl.metrics.Accuracy()
        self.input = nn.Linear(in_features=28*28, out_features=15*15)
        self.drop1 = nn.Dropout(0.05)
        self.fc1 = nn.Linear(in_features=15*15, out_features=10*10)
        self.drop2 = nn.Dropout(0.05)
        self.out = nn.Linear(in_features=10*10, out_features=10)

    def forward(self, x):
        batch_size = x.shape[0]
        # Flatten the image, then input -> drop1 -> fc1 -> drop2 -> out
        x = self.drop1(F.relu(self.input(x.reshape(batch_size, -1))))
        x = self.drop2(F.relu(self.fc1(x)))
        x = self.out(x)
        return x
class MetaLearner(pl.LightningModule):

    def __init__(self, model, model_lr, learner_lr):
        super().__init__()
        self.model = model
        self.model_lr = model_lr
        self.learner_lr = learner_lr
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # Outer (meta) optimizer over all parameters, inner optimizer over the wrapped model only
        meta_optim = torch.optim.Adam(self.parameters(), lr=self.learner_lr)
        model_optim = torch.optim.SGD(self.model.parameters(), lr=self.model_lr)
        return [meta_optim, model_optim]

    def task_loss(self, y_hat, y):
        # F.cross_entropy already averages over the batch
        return F.cross_entropy(y_hat, y)
    def training_step(self, batch, batch_idx, optimizer_idx):
        # Optimizers
        meta_optim, model_optim = self.optimizers()
        # Train and test data
        # TODO: actually split data between train and test
        X_train, Y_train = batch
        X_test, Y_test = batch
        meta_loss = torch.tensor(0.0, device=self.device)
        with higher.innerloop_ctx(self.model, model_optim, copy_initial_weights=False) as (fmodel, diff_opt):
            # Inner loop: adapt the functional copy on the task's train inputs
            y_hat = fmodel(X_train)
            model_loss = self.task_loss(y_hat, Y_train)
            diff_opt.step(model_loss)
            # Outer loop: evaluate the adapted parameters on the task's test inputs
            y_hat = fmodel(X_test)
            task_loss = self.task_loss(y_hat, Y_test)
            meta_loss += task_loss
            self.log("meta_loss_task0", task_loss)
            self.log("meta_acc_task0", self.accuracy(y_hat, Y_test))
        # Update meta-parameters (manual optimization)
        meta_optim.zero_grad()
        self.manual_backward(meta_loss, meta_optim)
        meta_optim.step()
        return meta_loss
    def validation_step(self, batch, _):
        x, y = batch
        y_hat = self(x)
        loss = self.task_loss(y_hat, y)
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_acc", self.accuracy(y_hat, y), on_epoch=True)
        return loss
class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir="./Data", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            mnist = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist, [50000, 10000])
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
# Weights and Biases Logger
wandb_logger = WandbLogger(project="Meta-Learning")
# Early Stopping Callback
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=False,
    mode='min'
)
# Create model
model = MyClassifier()
# Create DataModule
mnist = MNISTDataModule()
mnist.prepare_data()
mnist.setup()
# Create Learner
learner = MetaLearner(model, 1e-3, 1e-3)
# Create pytorch trainer
trainer = pl.Trainer(logger=wandb_logger, deterministic=True, callbacks=[early_stop_callback],
                     profiler=True, automatic_optimization=False)
# Train the model
trainer.fit(learner, datamodule=mnist)