I have a question about the AMP GradScaler when working with multiple optimisers that step unevenly. The pattern looks as follows:
```python
def do_train_step(wrapper, model_optimiser, constraint_optimiser, inputs):
    # wrapper contains two nn.Modules: model and constraint,
    # that both have their own optimiser: model_optimiser, constraint_optimiser,
    # and both interact with the inputs to form the loss
    with torch.set_grad_enabled(True):
        with autocast(enabled=True):
            loss = wrapper(inputs)

        # Gradients accumulate in here
        scaler.scale(loss).backward()

        # The model optimiser should only update once every <accumulate_n_batches_grad> steps
        if (global_step + 1) % accumulate_n_batches_grad == 0:
            # GRADIENT CLIPPING
            if gradient_clipping:
                # Unscales the gradients of the optimiser's assigned params in-place
                scaler.unscale_(model_optimiser)

                # Since the gradients of the optimiser's assigned params are unscaled, clip as usual:
                torch.nn.utils.clip_grad_norm_(wrapper.model.parameters(), 1.0)

            # OPTIMISER STEP (first unscaled)
            scaler.step(model_optimiser)

            # AMP SCALER UPDATE
            scaler.update()

            # Zero the gradients only here
            model_optimiser.zero_grad()

            # LR Scheduler
            lr_scheduler.step()

        # The constraint optimiser should update every train step
        scaler.step(constraint_optimiser)
        constraint_optimiser.zero_grad()

    return loss
```
But this alternating pattern seems to break the AMP scaling. I get the following error:
File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/cuda/amp/grad_scaler.py", line 302, in step raise RuntimeError("step() has already been called since the last update().") RuntimeError: step() has already been called since the last update().
which, as far as I can tell, is raised by the scaler.step(constraint_optimiser) call: that optimiser is stepped every iteration, but scaler.update() only runs on accumulation steps.
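From reading the GradScaler docs on working with multiple optimisers, my current guess is that scaler.update() should be called exactly once per iteration, after all step() calls of that iteration, rather than only inside the accumulation branch. A minimal sketch of the reordering I have in mind (untested, using the same globals as above: scaler, global_step, accumulate_n_batches_grad, gradient_clipping, lr_scheduler):

```python
def do_train_step(wrapper, model_optimiser, constraint_optimiser, inputs):
    with torch.set_grad_enabled(True):
        with autocast(enabled=True):
            loss = wrapper(inputs)

        # Gradients (still scaled) accumulate in here
        scaler.scale(loss).backward()

        # Model optimiser: only step once every <accumulate_n_batches_grad> steps
        if (global_step + 1) % accumulate_n_batches_grad == 0:
            if gradient_clipping:
                scaler.unscale_(model_optimiser)
                torch.nn.utils.clip_grad_norm_(wrapper.model.parameters(), 1.0)

            scaler.step(model_optimiser)
            model_optimiser.zero_grad()
            lr_scheduler.step()

        # Constraint optimiser: step every train step
        scaler.step(constraint_optimiser)
        constraint_optimiser.zero_grad()

        # Single update per iteration, after ALL step() calls of this iteration,
        # so no optimiser is stepped twice between updates
        scaler.update()

    return loss
```

What I am not sure about is whether calling scaler.update() every iteration is safe while gradients for model_optimiser are still accumulating, since the scale could in principle change between accumulation steps.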
If anyone spots what might be going wrong here, that would be very helpful.
Thanks in advance!