Double backward with checkpointed model and learnable inner loop learning rates

Hi! I’m trying to do MAML + gradient checkpointing + learnable learning rates.

My question is: why does multiplying the inner loop gradient by another tensor that requires grad break the outer loop backward pass when gradient checkpointing is enabled?

Right now, MAML + checkpointing and MAML + learnable lr both work, but all three together do not. Note that I mean checkpointing at each layer within the model itself, not checkpointing between MAML inner loop steps; e.g. my model is:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Named `layers` rather than `modules` to avoid shadowing nn.Module.modules()
        self.layers = nn.ModuleList([...])

    def forward(self, x):
        # Checkpoint each layer so activations are recomputed during backward
        for m in self.layers:
            x = checkpoint(m, x)
        return x

This is a bit tedious because torch.autograd.grad and checkpointing are incompatible, so I've implemented MAML using only autograd.backward. I've done some tests against higher, and I think it's correct.
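For reference, the parity check against higher looks roughly like this (a simplified sketch on a toy nn.Linear, not my actual test harness; it uses higher's innerloop_ctx with copy_initial_weights=False so outer grads land on the original parameters):

import torch
import higher
from torch import nn

model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(8, 4)

# Reference second-order MAML step via higher. copy_initial_weights=False
#  makes the fast weights differentiable functions of model's own parameters,
#  so the outer backward() populates model.parameters() grads.
with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
    inner_loss = -fmodel(x).mean()
    diffopt.step(inner_loss)
    outer_loss = fmodel(x).pow(2).mean()
    outer_loss.backward()

ref_grads = [p.grad.clone() for p in model.parameters()]
# ...then compare ref_grads against the grads produced by the mml_* functions below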

This is the code for my inner loop; the rest of my implementation is below.

from typing import List

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast


def mml_inner(model: nn.Module, loss: torch.Tensor, inner_lrs: List[nn.Parameter] = None,
              inner_lr: float = 1e-2, scaler: torch.cuda.amp.GradScaler = None):
    if scaler is not None:
        loss = scaler.scale(loss)

    # Create graph for double backward in outer loop
    loss.backward(create_graph=not (model.mml_first_order or model.mml_eval))

    # We need to track these so we can delete their grads later to prevent a
    #  reference cycle from backward(create_graph=True)
    model.mml_ctx.extend(list(model.parameters()))

    grad_scale_inv = 1. / scaler.get_scale() if scaler is not None else 1.
    with autocast(enabled=scaler is not None):
        for idx, m in enumerate(model.mml_modules):
            m_params = vars(m)['_parameters']
            for name, param in m_params.items():
                if param is None:
                    continue

                # If inner_lrs (learnable learning rates) aren't passed, default to scalar inner_lr
                lr = F.softplus(inner_lrs[idx]) if inner_lrs is not None else inner_lr
                update = lr * grad_scale_inv * param.grad  # <--- Anomaly detection picks out this line
                new_param = param - update
                if not model.mml_first_order:
                    new_param.retain_grad()
                m_params[name] = new_param
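The m_params[name] = new_param line is the usual fast-weight hack: poking a plain (non-Parameter) tensor into the module's _parameters dict so subsequent forwards use it while it stays on the autograd graph. A minimal sketch of the mechanism on a throwaway nn.Linear:

import torch
from torch import nn

lin = nn.Linear(2, 2)
w0 = lin.weight                               # original leaf Parameter
fast_w = w0 - 0.1 * torch.ones_like(w0)       # graph node, not a Parameter
vars(lin)['_parameters']['weight'] = fast_w   # forward now uses the fast weight

out = lin(torch.randn(1, 2)).sum()
out.backward()
print(w0.grad)  # grads flow back through fast_w to the original leaf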

When doing the outer loop backward pass with checkpointing and learnable learning rates, I get an in-place operation Exception:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This exception only occurs when I run with checkpointing and learnable learning rates together; everything is fine with only one or the other. If I pass None for inner_lrs, or simply pass a list of Parameters with requires_grad=False, everything works fine (but of course, we don’t get learnable learning rates).

Maybe this is because I don't fully understand autograd, but it seems odd that the in-place-modification error only triggers when the inner loop gradient is multiplied by another tensor that requires grad. I assume this has to do with how PyTorch's tensor versioning works. I'm also frankly a bit confused about why this code works at all: why do gradients propagate through the fast weights in the inner loop under checkpointing, given that they aren't inputs to the checkpointed sub-modules?
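To illustrate what I mean about versioning, here's a standalone sketch (not my training code) of the saved-tensor behavior I suspect is at play: multiplying by a tensor that requires grad forces autograd to save the other operand for the backward pass, and a later in-place edit of that operand invalidates it.

import torch

g = torch.ones(3, requires_grad=True) * 2.   # non-leaf, like param.grad under create_graph=True
lr = torch.tensor(0.1, requires_grad=True)   # learnable lr

out = (lr * g).sum()  # autograd saves g, since it's needed for d(out)/d(lr)
g.add_(1.0)           # in-place edit bumps g's version counter
out.backward()        # RuntimeError: ... modified by an inplace operation

# With lr = torch.tensor(0.1) (requires_grad=False), g isn't saved for
#  backward and the same add_ is harmless.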

For context, here is the rest of my MAML code:

import copy


def _mml_set_requires_grad(model, skip_modules, requires_grad):
    # Turn grads off (or on) everywhere except the adaptation modules,
    #  which always keep requires_grad=True
    for m in model.modules():
        for p in vars(m)['_parameters'].values():
            if p is not None:
                if m not in skip_modules:
                    p.requires_grad = requires_grad
                else:
                    p.requires_grad = True


def mml_ctx(m, inner_modules, first_order: bool = False, grad_clip: float = 1e9, eval_: bool = False):
    m_ = copy.deepcopy(m)
    # Wrap the original model in a lambda so nn.Module.__setattr__ doesn't
    #  register it as a submodule of the copy
    m_.mml_m0 = lambda: m
    m_.mml_first_order = first_order
    m_.mml_propagate_grad = False
    m_.mml_grad_propagation_hooks = []
    m_.mml_modules = inner_modules(m_)
    m_.mml_ctx = []
    m_.mml_grad_clip = grad_clip
    m_.mml_eval = eval_

    # This hack turns off grads for non-adaptation parameters to save memory/compute
    _mml_set_requires_grad(m, inner_modules(m), False)
    _mml_set_requires_grad(m_, m_.mml_modules, False)

    # Hooks mirror the adapted parameters' outer-loop grads back onto the
    #  original model's parameters (gated by mml_propagate_grad)
    for p, p_ in zip(m.parameters(), m_.parameters()):
        if not p_.requires_grad:
            continue

        def create_hook(model, p):
            def hook(grad):
                if model.mml_propagate_grad:
                    if p.grad is None:
                        p.grad = grad.detach()
                    else:
                        p.grad = (p.grad + grad).detach()
            return hook

        m_.mml_grad_propagation_hooks.append(p_.register_hook(create_hook(m_, p)))

    return m_
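If the hook dance above looks odd: it's just grad mirroring from the adapted copy back to the original parameter, gated by mml_propagate_grad so inner loop backwards don't write to the original model. A standalone illustration (without the gating flag):

import torch
from torch import nn

p = nn.Parameter(torch.ones(2))         # "original" parameter
p_ = nn.Parameter(p.detach().clone())   # adapted copy

def mirror(grad):
    # Accumulate the copy's grad onto the original; returning None
    #  leaves the copy's own grad untouched
    p.grad = grad.detach() if p.grad is None else (p.grad + grad).detach()

p_.register_hook(mirror)

(p_ * 3.).sum().backward()
print(p.grad)  # tensor([3., 3.]), mirrored by the hook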


def mml_flush(model):
    for h in model.mml_grad_propagation_hooks:
        h.remove()

    # We're done with this model (MAML-wise)
    for p in model.mml_ctx:
        del p.grad
    del model.mml_grad_propagation_hooks
    del model.mml_ctx


def mml_outer_backward(model, outer_loss, scaler = None):
    # Unlike in the inner loop, we want to copy grads
    #  back to the original model now
    model.mml_propagate_grad = True

    if scaler is not None:
        outer_loss = scaler.scale(outer_loss)

    outer_loss.backward()

    mml_flush(model)


def mml_outer_step(model, opt, scaler = None):
    # Update the original/pre-adaptation parameters
    if scaler is not None:
        scaler.unscale_(opt)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.mml_m0().parameters(), model.mml_grad_clip)
        scaler.step(opt)
        scaler.update()
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(model.mml_m0().parameters(), model.mml_grad_clip)
        opt.step()

    opt.zero_grad()

    del model.mml_m0

    return grad_norm

It's designed to be used like this:

def _inner_modules(m):
    return list(m.modules())[-40:]


def _run_mml(model, opt, scaler, half, x, x_, first_order):
    for i in range(x.shape[0]):
        inner_model = mml_ctx(model, _inner_modules, first_order=first_order)

        # Inner loop, 2 gradient steps just for testing
        for step in range(2):
            with autocast(enabled=half):
                pred = inner_model(x[i])
                loss = -pred.mean()

            # This adapts the model in-place
            mml_inner(inner_model, loss, scaler=scaler)

        # Outer loop
        with autocast(enabled=half):
            pred_ = inner_model(x_[i])
            loss_ = pred_.pow(2).mean()
        print(f'outer {i}: {loss_}')

        mml_outer_backward(inner_model, loss_, scaler=scaler)
        mml_outer_step(inner_model, opt, scaler=scaler)
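Note that _run_mml above omits inner_lrs for brevity; in the failing configuration the learnable learning rates are built along these lines (a simplified sketch, with raw values inverse-softplus-initialized so F.softplus(raw) starts near the 1e-2 default) and passed through to mml_inner:

import math

# One raw (pre-softplus) lr per inner module
raw_init = math.log(math.expm1(1e-2))  # inverse softplus of 1e-2
inner_lrs = [nn.Parameter(torch.tensor(raw_init)) for _ in _inner_modules(model)]

# The lrs are meta-learned, so they join the outer optimizer
opt = torch.optim.Adam(list(model.parameters()) + inner_lrs, lr=1e-3)

# ...and inside _run_mml the inner loop call becomes
mml_inner(inner_model, loss, inner_lrs=inner_lrs, scaler=scaler)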

Any insight would be greatly appreciated!