Hi! I’m trying to do MAML + gradient checkpointing + learnable learning rates.
My question: why does multiplying the inner-loop gradient by another tensor that requires grad break the outer-loop backward pass when gradient checkpointing is enabled?
Right now, MAML + checkpointing and MAML + learnable lr both work, but all 3 together do not. Note that I mean checkpointing at each layer within the model itself, not checkpointing between MAML inner loop steps; e.g. my model is:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # named `layers` so it doesn't shadow nn.Module.modules()
        self.layers = nn.ModuleList([...])

    def forward(self, x):
        # checkpoint each layer individually
        for m in self.layers:
            x = torch.utils.checkpoint.checkpoint(m, x)
        return x
This is a bit tedious because higher and checkpointing are incompatible, so I've implemented MAML by hand using only autograd.backward. I've done some tests against higher, and I think the implementation is correct.
This is the code for my inner loop; the rest of my implementation is below.
def mml_inner(model: nn.Module, loss: torch.Tensor, inner_lrs: List[torch.nn.Parameter] = None,
              inner_lr: float = 1e-2, scaler: torch.cuda.amp.GradScaler = None):
    if scaler is not None:
        loss = scaler.scale(loss)
    # Create graph for double backward in outer loop
    loss.backward(create_graph=not (model.mml_first_order or model.mml_eval))
    # We need to track these so we can delete their grads later and break the
    # reference cycle created by backward(create_graph=True)
    model.mml_ctx.extend(list(model.parameters()))
    grad_scale_inv = 1. / scaler.get_scale() if scaler is not None else 1.
    with autocast(enabled=scaler is not None):
        for idx, m in enumerate(model.mml_modules):
            m_params = vars(m)['_parameters']
            for name, param in m_params.items():
                if param is None:
                    continue
                # If inner_lrs (learnable learning rates) aren't passed, default to the scalar inner_lr
                lr = F.softplus(inner_lrs[idx]) if inner_lrs is not None else inner_lr
                update = lr * grad_scale_inv * param.grad  # <--- anomaly detection picks out this line
                new_param = param - update
                if not model.mml_first_order:
                    new_param.retain_grad()
                # Swap the fast weight into the module so the next forward uses it
                m_params[name] = new_param
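To make the fast-weight mechanism above concrete, here is a tiny standalone illustration (a toy nn.Linear, not my actual model): writing the non-leaf updated tensor straight into _parameters makes the next forward use it, while keeping it connected to the original parameter's graph.

import torch

lin = torch.nn.Linear(4, 4)
w0 = lin.weight                               # keep a handle on the slow weight
x = torch.randn(2, 4)

lin(x).sum().backward(create_graph=True)      # inner backward, graph retained
fast_w = w0 - 1e-2 * w0.grad                  # differentiable update (non-leaf)
vars(lin)['_parameters']['weight'] = fast_w   # lin.weight now resolves to the fast weight
w0.grad = None                                # don't mix inner and outer grads

lin(x).sum().backward()                       # outer-style backward
print(w0.grad)                                # gradients flow back to the slow weight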
When doing the outer loop backward pass with checkpointing and learnable learning rates, I get an in-place operation Exception:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
This exception only occurs when I run with checkpointing and learnable learning rates together; everything is fine with only one or the other. If I pass None for inner_lrs, or simply pass a list of Parameters with requires_grad=False, everything works fine (but of course, we don't get learnable learning rates).
Maybe this is because I don't fully understand autograd, but it seems odd to me that the error is triggered by an in-place update to the inner-loop parameter only when that parameter is multiplied by another tensor that requires grad. I assume this has to do with how PyTorch's tensor version counters work. I'm also frankly a bit confused about why this code works at all: why do gradients propagate through the fast weights in the inner loop with checkpointing, given that they aren't inputs to the checkpointed sub-modules?
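If I understand the version counter correctly, this toy snippet (standalone, nothing to do with my model) shows the same asymmetry: out = lr * a saves a for backward only when lr requires grad, because a is needed for d(out)/d(lr), so an in-place change to a only trips the version check in that case.

import torch

a = torch.randn(3, requires_grad=True)
lr = torch.tensor(0.5, requires_grad=True)  # stands in for a learnable lr

out = lr * a              # mul saves `a` here, since it's needed for d(out)/d(lr)
with torch.no_grad():
    a.add_(1.0)           # in-place update bumps a's version counter
out.sum().backward()      # RuntimeError: ... modified by an inplace operation

# With lr.requires_grad == False, `a` is never saved and the same code runs fine.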
For context, here is the rest of my MAML code:
def _mml_set_requires_grad(model, skip_modules, requires_grad):
    # Parameters of skip_modules (the adapted modules) always keep requires_grad=True;
    # everything else gets `requires_grad`
    for m in model.modules():
        for p in vars(m)['_parameters'].values():
            if p is not None:
                if m not in skip_modules:
                    p.requires_grad = requires_grad
                else:
                    p.requires_grad = True
def mml_ctx(m, inner_modules, first_order: bool = False, grad_clip: float = 1e9, eval_=False):
    m_ = copy.deepcopy(m)
    # Stored behind a lambda so the original model isn't registered as a submodule
    m_.mml_m0 = lambda: m
    m_.mml_first_order = first_order
    m_.mml_propagate_grad = False
    m_.mml_grad_propagation_hooks = []
    m_.mml_modules = inner_modules(m_)
    m_.mml_ctx = []
    m_.mml_grad_clip = grad_clip
    m_.mml_eval = eval_
    # This hack turns off grads for non-adaptation parameters to save memory/compute
    _mml_set_requires_grad(m, inner_modules(m), False)
    _mml_set_requires_grad(m_, m_.mml_modules, False)
    for p, p_ in zip(m.parameters(), m_.parameters()):
        if not p_.requires_grad:
            continue

        def create_hook(model, p):
            def hook(grad):
                # During the outer backward, copy grads from the adapted copy
                # back onto the original (pre-adaptation) parameters
                if model.mml_propagate_grad:
                    if p.grad is None:
                        p.grad = grad.detach()
                    else:
                        p.grad = (p.grad + grad).detach()
            return hook

        m_.mml_grad_propagation_hooks.append(p_.register_hook(create_hook(m_, p)))
    return m_
def mml_flush(model):
    for h in model.mml_grad_propagation_hooks:
        h.remove()
    # We're done with this model (MAML-wise)
    for p in model.mml_ctx:
        del p.grad
    del model.mml_grad_propagation_hooks
    del model.mml_ctx
def mml_outer_backward(model, outer_loss, scaler=None):
    # Unlike in the inner loop, we want to copy grads
    # back to the original model now
    model.mml_propagate_grad = True
    if scaler is not None:
        outer_loss = scaler.scale(outer_loss)
    outer_loss.backward()
    mml_flush(model)
def mml_outer_step(model, opt, scaler=None):
    # Update the original/pre-adaptation parameters
    if scaler is not None:
        scaler.unscale_(opt)
        grad = torch.nn.utils.clip_grad_norm_(model.mml_m0().parameters(), model.mml_grad_clip)
        scaler.step(opt)
        scaler.update()
    else:
        grad = torch.nn.utils.clip_grad_norm_(model.mml_m0().parameters(), model.mml_grad_clip)
        opt.step()
    opt.zero_grad()
    del model.mml_m0
    return grad
It’s designed to be used as
def _inner_modules(m):
    # Adapt only the last 40 modules in the inner loop
    return list(m.modules())[-40:]

def _run_mml(model, opt, scaler, half, x, x_, first_order):
    for i in range(x.shape[0]):
        inner_model = mml_ctx(model, _inner_modules, first_order=first_order)
        # Inner loop, 2 gradient steps just for testing
        for step in range(2):
            with autocast(enabled=half):
                pred = inner_model(x[i])
                loss = -pred.mean()
            # This adapts the model in-place
            mml_inner(inner_model, loss, scaler=scaler)
        # Outer loop
        with autocast(enabled=half):
            pred_ = inner_model(x_[i])
            loss_ = pred_.pow(2).mean()
        print(f'outer {i}: {loss_}')
        mml_outer_backward(inner_model, loss_, scaler=scaler)
        mml_outer_step(inner_model, opt, scaler=scaler)
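For completeness, the learnable learning rates are just one nn.Parameter per adapted module, passed to mml_inner as inner_lrs; roughly like this (the init value and the optimizer grouping below are illustrative, not my exact setup):

# One learnable lr per adapted module; softplus in mml_inner keeps it positive.
# -4.6 ~= softplus^-1(1e-2). Move these to the model's device if needed.
inner_lrs = [torch.nn.Parameter(torch.tensor(-4.6)) for _ in _inner_modules(model)]

opt = torch.optim.Adam(
    [{'params': model.parameters()},
     {'params': inner_lrs}],   # meta-learned alongside the model
    lr=1e-4,
)

# ...and inside the inner loop:
# mml_inner(inner_model, loss, inner_lrs=inner_lrs, scaler=scaler)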
Any insight would be greatly appreciated!