I want to train my transformer model with MAML gradient updates. I’ve found that higher
is an efficient way to go about it. I’m open to other methods also.
So the MetaDataset
wraps any GlueDataset
to give a list containing all classes when meta_dataset[0]
is called. So this will become, num_of_classes (N)
way K shot example.
I’ve written this, which extends HF Trainer
for MAML.
def train(self):
self.create_optimizer_and_scheduler(
int(
len(self.train_dataloader)
// self.args.gradient_accumulation_steps
* self.args.num_train_epochs
)
)
logger.info("***** Running training *****")
self.global_step = 0
self.epoch = 0
eval_step = [2 ** i for i in range(1, 20)]
inner_optimizer = torch.optim.SGD(
self.model.parameters(), lr=self.args.step_size
)
self.model.train()
tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")
# n_inner_iter = 5
self.optimizer.zero_grad()
query_dataloader = iter(self.train_dataloader)
for batch_idx, meta_batch in enumerate(tqdm_iterator):
target_batch = next(query_dataloader)
outer_loss = 0.0
# Loop through all classes
for inputs, target_inputs in zip(meta_batch, target_batch):
for k, v in inputs.items():
inputs[k] = v.to(self.args.device)
target_inputs[k] = v.to(self.args.device)
with higher.innerloop_ctx(
self.model, inner_optimizer, copy_initial_weights=False
) as (fmodel, diffopt):
inner_loss = fmodel(**inputs)[0]
diffopt.step(inner_loss)
outer_loss += fmodel(**target_inputs)[0]
self.global_step += 1
self.optimizer.step()
outer_loss.backward()
if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.args.max_grad_norm
)
# Run evaluation on task list
if self.global_step in eval_step:
output = self.prediction_loop(self.eval_dataloader, description = "Evaluation")
self.log(output.metrics)
output_dir = os.path.join(
self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
)
self.save_model(output_dir)
The above code doesn’t seem to work properly as in the accuracy doesn’t improve. Any directions/tips would be appreciated.