I have a question about the AMP GradScaler when working with multiple optimisers that step unevenly. The pattern looks as follows:
```python
def do_train_step(wrapper, model_optimiser, constraint_optimiser, inputs):
    # wrapper contains two nn.Modules: model and constraint,
    # that both have their own optimiser: model_optimiser, constraint_optimiser,
    # and both interact with the inputs to form the loss
    with torch.set_grad_enabled(True):
        with autocast(enabled=True):
            loss = wrapper(inputs)

        # Gradients accumulate in here
        scaler.scale(loss).backward()

        # The model optimiser should only update once every <accumulate_n_batches_grad> steps
        if (global_step + 1) % accumulate_n_batches_grad == 0:
            # GRADIENT CLIPPING
            if gradient_clipping:
                # Unscales the gradients of the optimiser's assigned params in-place
                scaler.unscale_(model_optimiser)

                # Since the gradients of the optimiser's assigned params are unscaled, clip as usual:
                torch.nn.utils.clip_grad_norm_(wrapper.model.parameters(), 1.0)

            # OPTIMISER STEP (first unscaled)
            scaler.step(model_optimiser)

            # AMP SCALER UPDATE
            scaler.update()

            # Zero the gradients only here
            model_optimiser.zero_grad()

            # LR Scheduler
            lr_scheduler.step()

        # The constraint optimiser should update every train step
        scaler.step(constraint_optimiser)
        constraint_optimiser.zero_grad()

    return loss
```
But this alternating pattern seems to break the AMP scaling. I get the following error:
File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/cuda/amp/grad_scaler.py", line 302, in step raise RuntimeError("step() has already been called since the last update().") RuntimeError: step() has already been called since the last update().
which, as far as I can tell, is raised by the scaler.step(constraint_optimiser) call: that optimiser is stepped every iteration, but scaler.update() only runs on accumulation steps.
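From reading the GradScaler docs on working with multiple optimisers, my current guess is that scaler.update() should be called exactly once per iteration, after all step() calls of that iteration, rather than only inside the accumulation branch. A minimal sketch of the reordering I have in mind (untested, using the same globals as above: scaler, global_step, accumulate_n_batches_grad, gradient_clipping, lr_scheduler):

```python
def do_train_step(wrapper, model_optimiser, constraint_optimiser, inputs):
    with torch.set_grad_enabled(True):
        with autocast(enabled=True):
            loss = wrapper(inputs)

        # Gradients (still scaled) accumulate in here
        scaler.scale(loss).backward()

        # Model optimiser: only step once every <accumulate_n_batches_grad> steps
        if (global_step + 1) % accumulate_n_batches_grad == 0:
            if gradient_clipping:
                scaler.unscale_(model_optimiser)
                torch.nn.utils.clip_grad_norm_(wrapper.model.parameters(), 1.0)

            scaler.step(model_optimiser)
            model_optimiser.zero_grad()
            lr_scheduler.step()

        # Constraint optimiser: step every train step
        scaler.step(constraint_optimiser)
        constraint_optimiser.zero_grad()

        # Single update per iteration, after ALL step() calls of this iteration,
        # so no optimiser is stepped twice between updates
        scaler.update()

    return loss
```

What I am not sure about is whether calling scaler.update() every iteration is safe while gradients for model_optimiser are still accumulating, since the scale could in principle change between accumulation steps.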
If anyone spots what might be going wrong here, that would be very helpful.
Thanks in advance!