Meta-learning with Transformers

I want to train my transformer model with MAML-style gradient updates. I’ve found that the higher library is an efficient way to go about this, but I’m open to other methods as well.
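
For context, the basic inner/outer-loop pattern I’m trying to follow with higher looks roughly like this (a minimal, self-contained sketch with a toy model, toy data, and placeholder learning rates, just to show the intended order of operations):

import torch
import higher

# Toy stand-ins for the real transformer, data, and loss
model = torch.nn.Linear(4, 2)
loss_fn = torch.nn.CrossEntropyLoss()
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2)

support_x, support_y = torch.randn(8, 4), torch.randint(0, 2, (8,))
query_x, query_y = torch.randn(8, 4), torch.randint(0, 2, (8,))

meta_opt.zero_grad()
with higher.innerloop_ctx(
    model, inner_opt, copy_initial_weights=False
) as (fmodel, diffopt):
    # Inner loop: adapt on the support set through a differentiable optimizer
    diffopt.step(loss_fn(fmodel(support_x), support_y))
    # Outer loss: the adapted weights evaluated on the query set
    loss_fn(fmodel(query_x), query_y).backward()
meta_opt.step()  # meta-update of the original parameters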

My MetaDataset wraps any GlueDataset so that meta_dataset[0] returns a list with one group of examples per class, i.e. an N-way, K-shot episode (N = number of classes, K = examples per class).
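
Conceptually it does something like this (a simplified sketch, not the exact implementation; the k_shot default and the episode count in __len__ are placeholders):

import random
from collections import defaultdict

import torch


class MetaDataset(torch.utils.data.Dataset):
    """Each item is an N-way, K-shot episode: a list with one K-shot group of
    examples per class (a data collator is assumed to batch each group later)."""

    def __init__(self, glue_dataset, k_shot=8):
        self.dataset = glue_dataset
        self.k_shot = k_shot
        # Group example indices by label
        self.indices_by_class = defaultdict(list)
        for idx in range(len(glue_dataset)):
            self.indices_by_class[int(glue_dataset[idx].label)].append(idx)

    def __getitem__(self, index):
        # Sample K examples from every class -> one N-way, K-shot episode
        return [
            [self.dataset[i] for i in random.sample(indices, self.k_shot)]
            for indices in self.indices_by_class.values()
        ]

    def __len__(self):
        n_classes = len(self.indices_by_class)
        return len(self.dataset) // (self.k_shot * n_classes)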

I’ve written the following, which extends the HF Trainer to do MAML:

def train(self):

        self.create_optimizer_and_scheduler(
            int(
                len(self.train_dataloader)
                // self.args.gradient_accumulation_steps
                * self.args.num_train_epochs
            )
        )

        logger.info("***** Running training *****")

        self.global_step = 0
        self.epoch = 0

        eval_step = [2 ** i for i in range(1, 20)]
        inner_optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.args.step_size
        )
        self.model.train()

        tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")

        #  n_inner_iter = 5
        self.optimizer.zero_grad()
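        # Second iterator over the same dataloader; its batches serve as the query sets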
        query_dataloader = iter(self.train_dataloader)

        for batch_idx, meta_batch in enumerate(tqdm_iterator):
            target_batch = next(query_dataloader)
            outer_loss = 0.0
            # Loop through all classes
            for inputs, target_inputs in zip(meta_batch, target_batch):

                for k, v in inputs.items():
                    inputs[k] = v.to(self.args.device)
                    target_inputs[k] = target_inputs[k].to(self.args.device)

                with higher.innerloop_ctx(
                    self.model, inner_optimizer, copy_initial_weights=False
                ) as (fmodel, diffopt):

                    # Inner loop: adapt the functional model on this class's support batch
                    inner_loss = fmodel(**inputs)[0]
                    diffopt.step(inner_loss)
                    # Outer loss: evaluate the adapted weights on the query batch
                    outer_loss += fmodel(**target_inputs)[0]

            self.global_step += 1

            # Backprop the accumulated meta-loss; the meta-optimizer only steps
            # on gradient-accumulation boundaries
            outer_loss.backward()

            if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.args.max_grad_norm
                )
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

            # Run evaluation on task list
            if self.global_step in eval_step:
                output = self.prediction_loop(
                    self.eval_dataloader, description="Evaluation"
                )
                self.log(output.metrics)

                output_dir = os.path.join(
                    self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
                )
                self.save_model(output_dir)

The above code doesn’t seem to work properly, in that the accuracy doesn’t improve during training. Any directions or tips would be appreciated.