After reading the paper “Sharpness-Aware Minimization for Efficiently Improving Generalization,” I’ve been interested in trying this optimizer with PyTorch. There is an unofficial implementation at this repo, which wraps a base optimizer. Instead of the usual step() function, it implements a first_step() followed by a second_step().
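For context, here is my rough understanding of what the two steps do, without any AMP involved. This is a simplified sketch based on the paper, not the repo's actual code (the class name MinimalSAM and the exact formulas are my own reconstruction):

```python
import torch

class MinimalSAM(torch.optim.Optimizer):
    """Simplified SAM sketch: first_step climbs to a nearby high-loss point,
    second_step returns to the original weights and applies the base update."""

    def __init__(self, params, base_optimizer_cls, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer shares our param groups
        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # w <- w + rho * g / ||g||  (ascend toward the local loss maximum)
        grad_norm = torch.norm(torch.stack([
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group["params"] if p.grad is not None
        ]), p=2)
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                self.state[p]["e_w"] = e_w  # remember the perturbation
                p.add_(e_w)
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Undo the perturbation, then step using the gradient
        # computed at the perturbed point
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()


# Two forward/backward passes per update, mirroring the repo's README
w = torch.nn.Parameter(torch.tensor([3.0]))
optimizer = MinimalSAM([w], torch.optim.SGD, rho=0.05, lr=0.1)

(w ** 2).sum().backward()             # first pass: gradient at w
optimizer.first_step(zero_grad=True)
(w ** 2).sum().backward()             # second pass: gradient at w + e_w
optimizer.second_step(zero_grad=True)
```

The important point for what follows is that two separate backward passes produce gradients before the base optimizer's step() ever runs.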
With a GradScaler, I might have training code like this:
optimizer.zero_grad()
batch = batch.to(device)
label = label.to(device)

with autocast():
    pred = model(batch)
    loss = criterion(pred, label)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()
I’m trying to understand how this would work with a custom optimizer like SAM. This implementation wraps a base optimizer such as SGD and needs two forward and backward passes before the base optimizer can actually update the model. My thought was to handle everything manually: instead of letting the SAM optimizer class call step() on the optimizer it wraps, I do it myself and manually unscale the gradients.
# Set up base optimizer and SAM, which wraps it
base_optimizer = torch.optim.SGD(trainable_layer_params, 1e-3, momentum=0.9)
optimizer = SAM(trainable_layer_params, base_optimizer, lr=0.1, momentum=0.9)
...

# Model training step
optimizer.zero_grad()

with autocast():
    loss = criterion(model(batch), label)
scaler.scale(loss).backward()
scaler.unscale_(base_optimizer)
optimizer.first_step(zero_grad=True)  # Ascend model params to local maximum
scaler.update(1.0)

with autocast():
    loss = criterion(model(batch), label)
scaler.scale(loss).backward()
scaler.unscale_(base_optimizer)
optimizer.second_step()  # Descend from local maximum
scaler.step(base_optimizer)
scaler.update()
scheduler.step()
However, this results in an error:
34 loss = criterion(model(batch), label)
35 scaler.scale(loss).backward()
---> 36 scaler.unscale_(base_optimizer)
37 optimizer.first_step(zero_grad=True)
38 scaler.update(1.0)
/home/user/.local/lib/python3.6/site-packages/torch/cuda/amp/grad_scaler.py in unscale_(self, optimizer)
256
257 if optimizer_state["stage"] is OptState.UNSCALED:
--> 258 raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
259 elif optimizer_state["stage"] is OptState.STEPPED:
260 raise RuntimeError("unscale_() is being called after step().")
RuntimeError: unscale_() has already been called on this optimizer since the last update().
What would be the way to handle unscaling when the gradients need to be unscaled twice before updating the model? More broadly, how should one approach these kinds of issues with a custom optimizer?