Hi! I’m trying to do MAML + gradient checkpointing + learnable learning rates.

My question is **why does multiplying the inner loop gradient with another tensor requiring grad break the outer loop backward pass when gradient checkpointing is enabled?**

Right now, MAML + checkpointing and MAML + learnable lr both work, but all 3 together do not. Note that I mean checkpointing at each layer within the model itself, **not** checkpointing between MAML inner loop steps; e.g. my model is:

```
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # "blocks" rather than "modules", which would shadow nn.Module.modules()
        self.blocks = nn.ModuleList([...])

    def forward(self, x):
        # Checkpoint each layer so its activations are recomputed during backward
        for m in self.blocks:
            x = checkpoint(m, x)
        return x
```

This is a bit tedious because `torch.autograd.grad` and checkpointing are incompatible, but I think I’ve got MAML working using only `autograd.backward`. I’ve done some tests against `higher`, and I think it’s correct.
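For anyone who hasn’t hit this limitation: here’s a minimal standalone illustration (my own toy example, not from my MAML code) of why I stick to `backward()`. The default (reentrant) checkpoint implementation only supports gradients flowing via `torch.autograd.backward()`:

```
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)  # reentrant checkpoint needs an input that requires grad
y = checkpoint(lin, x).sum()  # newer PyTorch versions want use_reentrant=True passed explicitly

y.backward()  # fine: grads reach lin's parameters through the recomputed forward
# torch.autograd.grad(y, list(lin.parameters()))  # would fail with the reentrant variant
```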

This is the code for my inner loop; the rest of my implementation is below.

```
from typing import List

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast


def mml_inner(model: nn.Module, loss: torch.Tensor, inner_lrs: List[torch.nn.Parameter] = None,
              inner_lr: float = 1e-2, scaler: torch.cuda.amp.GradScaler = None):
    if scaler is not None:
        loss = scaler.scale(loss)
    # Create graph for double backward in outer loop
    loss.backward(create_graph=not (model.mml_first_order or model.mml_eval))
    # We need to track these so we can delete their grads later to prevent a
    # reference cycle from backward(create_graph=True)
    model.mml_ctx.extend(list(model.parameters()))
    grad_scale_inv = 1. / scaler.get_scale() if scaler is not None else 1.
    with autocast(enabled=scaler is not None):
        for idx, m in enumerate(model.mml_modules):
            m_params = vars(m)['_parameters']
            for name, param in m_params.items():
                if param is None:
                    continue
                # If inner_lrs (learnable learning rates) aren't passed, default to scalar inner_lr
                lr = F.softplus(inner_lrs[idx]) if inner_lrs is not None else inner_lr
                update = lr * grad_scale_inv * param.grad  # <--- anomaly detection picks out this line
                new_param = param - update
                if not model.mml_first_order:
                    new_param.retain_grad()
                # Swap the fast weight in place of the original parameter
                m_params[name] = new_param
```
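For completeness, I build the `inner_lrs` roughly like this (a sketch; `num_inner_modules` is a placeholder, and the inverse-softplus init is just so the effective starting LR matches the scalar default):

```
import math
import torch

num_inner_modules = 40  # placeholder: one learnable LR per adapted module
init = math.log(math.expm1(1e-2))  # softplus(init) == 1e-2, the scalar default
inner_lrs = torch.nn.ParameterList(
    [torch.nn.Parameter(torch.tensor(init)) for _ in range(num_inner_modules)]
)
```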

When doing the outer loop backward pass with checkpointing and learnable learning rates, I get an in-place operation Exception:

`RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).`

This exception only occurs when I run with checkpointing and learnable learning rates together; everything is fine with only one or the other. If I pass `None` for `inner_lrs`, or simply pass a list of Parameters with `requires_grad=False`, everything works fine (but of course, we don’t get learnable learning rates).

Maybe this is because I don’t fully understand autograd, but it’s kind of weird to me that the error triggers for an in-place update to the inner loop parameter **only when it’s multiplied with another tensor that requires grad**. I assume this has to do with how PyTorch tensor versioning works. I’m also frankly a bit confused about why this code works at all: why do gradients propagate through the fast weights in the inner loop under checkpointing, given that they aren’t inputs to the checkpointed sub-modules?
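To make the versioning suspicion concrete, here is my mental model as a tiny standalone repro (again, my own illustration, not my actual code): `mul` only saves an operand for backward if the *other* operand requires grad, so the version check on `w` only fires when `lr` requires grad:

```
import torch

w = torch.randn(3, requires_grad=True)
lr = torch.randn((), requires_grad=True)  # flip to False and the error disappears
out = (lr * w).sum()  # mul saves w, because it's needed for lr's gradient
with torch.no_grad():
    w.add_(1.0)       # in-place op bumps w's version counter: 0 -> 1
out.backward()        # RuntimeError: ... is at version 1; expected version 0
```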

For context, here is the rest of my MAML code:

```
import copy

import torch


def _mml_set_requires_grad(model, skip_modules, requires_grad):
    for m in model.modules():
        for p in vars(m)['_parameters'].values():
            if p is not None:
                if m not in skip_modules:
                    p.requires_grad = requires_grad
                else:
                    p.requires_grad = True


def mml_ctx(m, inner_modules, first_order: bool = False, grad_clip: float = 1e9, eval_=False):
    # Adapt a deep copy; the original model m only receives outer-loop grads
    m_ = copy.deepcopy(m)
    m_.mml_m0 = lambda: m
    m_.mml_first_order = first_order
    m_.mml_propagate_grad = False
    m_.mml_grad_propagation_hooks = []
    m_.mml_modules = inner_modules(m_)
    m_.mml_ctx = []
    m_.mml_grad_clip = grad_clip
    m_.mml_eval = eval_
    # This hack turns off grads for non-adaptation parameters to save memory/compute
    _mml_set_requires_grad(m, inner_modules(m), False)
    _mml_set_requires_grad(m_, m_.mml_modules, False)
    for p, p_ in zip(m.parameters(), m_.parameters()):
        if not p_.requires_grad:
            continue

        def create_hook(model, p):
            def hook(grad):
                # Accumulate the copy's grad onto the original parameter,
                # but only once the outer backward pass is running
                if model.mml_propagate_grad:
                    if p.grad is None:
                        p.grad = grad.detach()
                    else:
                        p.grad = (p.grad + grad).detach()
            return hook

        m_.mml_grad_propagation_hooks.append(p_.register_hook(create_hook(m_, p)))
    return m_


def mml_flush(model):
    for h in model.mml_grad_propagation_hooks:
        h.remove()
    # We're done with this model (MAML-wise)
    for p in model.mml_ctx:
        del p.grad
    del model.mml_grad_propagation_hooks
    del model.mml_ctx


def mml_outer_backward(model, outer_loss, scaler=None):
    # Unlike in the inner loop, we want to copy grads
    # back to the original model now
    model.mml_propagate_grad = True
    if scaler is not None:
        outer_loss = scaler.scale(outer_loss)
    outer_loss.backward()
    mml_flush(model)


def mml_outer_step(model, opt, scaler=None):
    # Update the original/pre-adaptation parameters
    if scaler is not None:
        scaler.unscale_(opt)
        grad = torch.nn.utils.clip_grad_norm_(model.mml_m0().parameters(), model.mml_grad_clip)
        scaler.step(opt)
        scaler.update()
    else:
        grad = torch.nn.utils.clip_grad_norm_(model.mml_m0().parameters(), model.mml_grad_clip)
        opt.step()
    opt.zero_grad()
    del model.mml_m0
    return grad
```

It’s designed to be used like this:

```
from torch.cuda.amp import autocast


def _inner_modules(m):
    return list(m.modules())[-40:]


def _run_mml(model, opt, scaler, half, x, x_, first_order):
    for i in range(x.shape[0]):
        inner_model = mml_ctx(model, _inner_modules, first_order=first_order)
        # Inner loop, 2 gradient steps just for testing
        for step in range(2):
            with autocast(enabled=half):
                pred = inner_model(x[i])
                loss = -pred.mean()
            # This adapts the model in-place
            mml_inner(inner_model, loss, scaler=scaler)
        # Outer loop
        with autocast(enabled=half):
            pred_ = inner_model(x_[i])
            loss_ = pred_.pow(2).mean()
        print(f'outer {i}: {loss_}')
        mml_outer_backward(inner_model, loss_, scaler=scaler)
        mml_outer_step(inner_model, opt, scaler=scaler)
```
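e.g. a smoke test along these lines (the shapes, optimizer, and LR here are placeholders, not my real training setup):

```
import torch

model = MyModel().cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
# 4 tasks, each with a support batch x[i] and a query batch x_[i]
x = torch.randn(4, 8, 768, device='cuda')
x_ = torch.randn(4, 8, 768, device='cuda')
_run_mml(model, opt, scaler, half=True, x=x, x_=x_, first_order=False)
```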

Any insight would be greatly appreciated!