Autograd doesn't retain the computation graph on a meta-learning algorithm

Hi,

I’m trying to implement the algorithm described here.
A short description of the related part is as follows.

My attempt is as follows.

# 1. Forward-backward pass on training data
_, (inputs, labels) =  next(enumerate(train_loader))
inputs, labels = inputs.to(device=args.device, non_blocking=True),\
                        labels.to(device=args.device, non_blocking=True)
meta_model.load_state_dict(model.state_dict())

y_hat_f = meta_model(inputs)
criterion.reduction = 'none'
l_f = criterion(y_hat_f, labels)
eps = torch.rand(l_f.size(), requires_grad=False, device=args.device).div(1e6)
eps.requires_grad = True
l_f = torch.sum(eps * l_f)

# 2. Compute grads wrt model and update its params
l_f.backward(retain_graph=True)
meta_optimizer.step()

# 3. Forward-backward pass on meta data with updated model
_, (inputs, labels) =  next(enumerate(meta_loader))
inputs, labels = inputs.to(device=args.device, non_blocking=True),\
                        labels.to(device=args.device, non_blocking=True)

y_hat_g = model(inputs)
criterion.reduction = 'mean'
l_g = criterion(y_hat_g, labels)

# 4. Compute grads wrt eps and update weights
eps_grads = torch.autograd.grad(l_g, eps)
.....

At this line:

eps_grads = torch.autograd.grad(l_g, eps)

I get an error saying,

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

If I set allow_unused=True, it returns None for eps_grads.

As far as I can tell, autograd loses the computation graph for some reason and doesn’t retain the information that eps was used in the computation of l_f, which in turn was used to update the parameters of the model. So l_g should be differentiable with respect to eps, but it doesn’t work here.

How can I solve this?

Hi,

I think you’re missing create_graph=True in the first backward call, which you need if you want to be able to backprop through the backward pass.
Also, PyTorch’s optimizers are not differentiable (yet), so you won’t be able to backprop through that step either.

I would recommend using the higher package to do this: https://github.com/facebookresearch/higher
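
To make the two points above concrete, here is a minimal sketch on a toy linear model. The data, the learning rate lr, and the zero initialisation of eps are assumptions made for illustration, not taken from the original code. It shows that a gradient computed without create_graph=True carries no history, and that a hand-written SGD step built from a create_graph=True gradient lets l_g be differentiated with respect to eps.

```python
import torch

torch.manual_seed(0)

# Toy linear model held as a plain tensor (hypothetical stand-in for the real model).
w = torch.randn(3, 1, requires_grad=True)
x_train, y_train = torch.randn(8, 3), torch.randn(8, 1)
x_meta, y_meta = torch.randn(4, 3), torch.randn(4, 1)
lr = 0.1  # hypothetical inner-loop learning rate

# Per-example weights whose gradient we want (zero-initialised for simplicity).
eps = torch.zeros(8, requires_grad=True)

# 1. Weighted training loss.
per_example = ((x_train @ w - y_train) ** 2).squeeze(1)
l_f = torch.sum(eps * per_example)

# 2a. A plain gradient is detached from the graph, so eps is "unused" downstream.
(grad_plain,) = torch.autograd.grad(l_f, w, retain_graph=True)
# 2b. With create_graph=True the gradient itself is a differentiable function of eps.
(grad_w,) = torch.autograd.grad(l_f, w, create_graph=True)
print(grad_plain.requires_grad, grad_w.requires_grad)  # False True

# 3. Hand-written, differentiable SGD step (replaces optimizer.step()).
w_updated = w - lr * grad_w

# 4. Meta loss with the updated weights, then its gradient with respect to eps.
l_g = torch.mean((x_meta @ w_updated - y_meta) ** 2)
(eps_grads,) = torch.autograd.grad(l_g, eps)
print(eps_grads.shape)
```

Replacing grad_w with grad_plain in step 3 reproduces the "appears to not have been used in the graph" error, because w_updated then no longer depends on eps.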

So I should update the parameters manually, presumably by looping through them and subtracting the gradients, given that optimizers are not differentiable?

Yes.
But the tricky bit is that nn.Parameter objects are built to be the parameters that you learn, so they are leaf tensors and cannot have history. You will have to delete them from the module and replace them with the new updated values as plain Tensors (and keep the original Parameters in a different place so that you can still update them with your optimizer).
That is why I recommended the library above, which does all of that for you.
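
For readers who want to see the idea without an extra dependency, here is a rough sketch that leaves the module's own nn.Parameter objects untouched and runs the forward pass with updated plain tensors through torch.func.functional_call (available in PyTorch 2.0 and later). The nn.Linear model, the random data, and lr are all assumptions for illustration. This is not how higher is implemented internally, only one way to get the same effect.

```python
import torch
from torch import nn
from torch.func import functional_call  # PyTorch 2.0+

torch.manual_seed(0)

# Hypothetical tiny classifier and data, for illustration only.
model = nn.Linear(3, 2)
criterion = nn.CrossEntropyLoss(reduction="none")
x_train, y_train = torch.randn(8, 3), torch.randint(0, 2, (8,))
x_meta, y_meta = torch.randn(4, 3), torch.randint(0, 2, (4,))
lr = 0.1  # hypothetical inner-loop learning rate

eps = torch.zeros(8, requires_grad=True)
params = dict(model.named_parameters())

# Weighted training loss and a differentiable gradient with respect to the parameters.
l_f = torch.sum(eps * criterion(model(x_train), y_train))
grads = torch.autograd.grad(l_f, list(params.values()), create_graph=True)

# The updated values are ordinary tensors with history; the real Parameters
# stay in the module so a normal optimizer can still update them later.
fast_params = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Forward pass on the meta batch using the updated tensors instead of the Parameters.
meta_logits = functional_call(model, fast_params, (x_meta,))
l_g = criterion(meta_logits, y_meta).mean()
(eps_grads,) = torch.autograd.grad(l_g, eps)
print(eps_grads.shape)
```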

Gotcha. Thanks a lot.

Thanks a lot for pointing me to this library. It made things a lot easier.

I have a question regarding performance, though. The code is roughly 4-5 times slower than it used to be. Given that the meta-learning algorithm does three forward-backward passes, this seems sort of OK to me. Just in case, could you take a quick look and tell me whether I’m doing something grossly inefficient in the following code? Also, I’ve been able to replicate an experiment from the paper, so I believe it is OK from the correctness perspective.

for rnd in tqdm(range(1, args.epochs+1)):
    model.train()
    train_loss, train_acc = 0.0, 0.0 
    
    for inputs, labels in train_loader:
        inputs = inputs.to(device=args.device, non_blocking=True)
        labels = labels.to(device=args.device, non_blocking=True)
        opt.zero_grad()
        
        with higher.innerloop_ctx(model, opt) as (meta_model, meta_opt):
            # 1. Update meta model on training data
            meta_train_outputs = meta_model(inputs)
            criterion.reduction = 'none'
            meta_train_loss = criterion(meta_train_outputs, labels)
            eps = torch.rand(meta_train_loss.size(), requires_grad=False, device=args.device).div(1e6)
            eps.requires_grad = True
            meta_train_loss = torch.sum(eps * meta_train_loss)
            meta_opt.step(meta_train_loss)
            
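            # 2. Compute grads of eps on meta validation data
            # (meta_inputs / meta_labels are assumed to be a batch from meta_loader,
            #  loaded elsewhere; that part is not shown in this snippet)
            meta_val_outputs = meta_model(meta_inputs)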
            criterion.reduction = 'mean'
            meta_val_loss = criterion(meta_val_outputs, meta_labels)
            eps_grads = torch.autograd.grad(meta_val_loss, eps, only_inputs=True)[0].detach()
            
        # 3. Compute weights for current training batch
        w_tilde = torch.clamp(-eps_grads, min=0)
        l1_norm = torch.sum(w_tilde)
        if l1_norm != 0:
            w = w_tilde / l1_norm
        else:
            w = w_tilde
            
        # 4. Train model on weighted batch
        outputs = model(inputs)
        criterion.reduction = 'none'
        minibatch_loss = criterion(outputs, labels)
        minibatch_loss = torch.sum(w * minibatch_loss)
        minibatch_loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        opt.step()
        
        # keep track of epoch loss/accuracy
        train_loss += minibatch_loss.item()*outputs.shape[0]
        _, pred_labels = torch.max(outputs, 1)
        train_acc += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()
        
    # inference after epoch
    with torch.no_grad():
        train_loss, train_acc = train_loss/len(train_dataset), train_acc/len(train_dataset)       
        val_loss, (val_acc, val_per_class) = utils.get_loss_n_accuracy(model, criterion, val_loader, args)                                  
        scheduler.step(val_loss)
        # log/print data
        #writer.add_scalar('Validation/Loss', val_loss, rnd)
        #writer.add_scalar('Validation/Accuracy', val_acc, rnd)
        #writer.add_scalar('Training/Loss', train_loss, rnd)
        #writer.add_scalar('Training/Accuracy', train_acc, rnd)
        print(f'|Train/Valid Loss: {train_loss:.3f} / {val_loss:.3f}|', end='--')
        print(f'|Train/Valid Acc: {train_acc:.3f} / {val_acc:.3f}|', end='\r')
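
Step 3 of the loop above (turning eps_grads into batch weights) can be checked in isolation. The helper name normalize_weights and the sample gradient values below are made up for this sketch; the arithmetic mirrors the clamp-and-normalise logic in the loop.

```python
import torch

def normalize_weights(eps_grads: torch.Tensor) -> torch.Tensor:
    """Clamp negated gradients at zero and rescale them to sum to 1 when possible."""
    w_tilde = torch.clamp(-eps_grads, min=0)
    l1_norm = w_tilde.sum()
    # If every example was clamped to zero, keep the all-zero weights as they are.
    return w_tilde / l1_norm if l1_norm > 0 else w_tilde

# Examples with negative gradients (helpful on the meta batch) get positive weight.
w = normalize_weights(torch.tensor([-2.0, 1.0, -6.0, 0.0]))
print(w)  # tensor([0.2500, 0.0000, 0.7500, 0.0000])
```

When all gradients are non-negative, every weight is zero and the batch contributes nothing to the update in step 4.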

Just in case, full code:

Your old code was not running any backward, so it is expected that this version is slower.

The code looks ok to me.
