In the step function of GradScaler, if a closure is passed in kwargs, a RuntimeError is raised by the following code.
The error message says “not currently supported”. Are there any plans to support this feature? Or could you tell me the reason why closures aren’t supported? If possible, I’d like to try to write a patch to solve this issue.
My motivation is to support SAM with native amp easily.
Making closures work with dynamic gradient scaling (specifically, the fact that dynamic gradient scaling occasionally skips optimizer.step() if any grads were inf/nan) is tricky, and we haven’t heard any use cases that absolutely needed it (LBFGS is the only one I’m aware of and no one’s asked for that).
Can you implement SAM without a closure like this?
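For reference, the non-closure usage from davda54’s repo looks roughly like the sketch below. This is only a sketch: model, criterion, and dataset are placeholders, and SAM here refers to that repository’s implementation, which wraps a base optimizer and exposes first_step()/second_step().

import torch
from sam import SAM  # davda54's implementation; adjust the import to your project layout

# model, criterion, and dataset are assumed to exist.
base_optimizer = torch.optim.SGD  # any base optimizer class works
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

for input, label in dataset:
    # First forward/backward pass: gradients at the current weights.
    criterion(model(input), label).backward()
    optimizer.first_step(zero_grad=True)  # perturb the weights toward the "sharp" direction

    # Second forward/backward pass: gradients at the perturbed weights.
    criterion(model(input), label).backward()
    optimizer.second_step(zero_grad=True)  # restore the weights and apply the base optimizer update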
Thank you for your reply!
As you pointed out, I use davda54’s SAM implementation without a closure, and it works well.
But I use pytorch-lightning, and in that environment I have to write some extra code to compute the gradients twice. My motivation is to avoid that; I want to use SAM like any other optimizer.
I understand that closure support is tricky, so I’ve decided to use SAM without a closure. Thanks!!
Hi, I have the same problem you did. How did you get fp16 support working without a closure using the linked GitHub repo? It seems to me that we have to call scaler.step(optimizer), which in turn calls optimizer.step() with no closure argument. To use the library we need to call first_step() and second_step(), so I don’t see how it’s possible to use this implementation with float16 support.
Hi Alex, I ran into the same problem recently and figured out a way that uses vanilla torch. You still run scaler.scale(loss).backward(), but instead of calling scaler.step(optimizer), you unscale the gradients with scaler.unscale_(optimizer), check that they are finite, and only then call optimizer.first_step() followed by scaler.update(). You then repeat the same sequence for the second step. Here’s an example:
import torch

# model, criterion, dataset, the SAM optimizer, and the GradScaler (`scaler`)
# are assumed to be set up already.

def are_grads_finite(params):
    # Gather all existing grads into one flat vector and check for inf/nan.
    grads = [p.grad for p in params if p.grad is not None]
    vec = torch.nn.utils.parameters_to_vector(grads)
    return bool(torch.isfinite(vec).all())

for input, label in dataset:
    # First pass
    with torch.cuda.amp.autocast():
        output = model(input)
        loss = criterion(output, label)
    scaler.scale(loss).backward()

    scaler.unscale_(optimizer)
    if not are_grads_finite(model.parameters()):
        optimizer.zero_grad()
        scaler.update()
        continue  # skip the batch if grads are non-finite, like scaler.step() does
    optimizer.first_step(zero_grad=True)
    scaler.update()

    # Second pass
    with torch.cuda.amp.autocast():
        output = model(input)
        loss = criterion(output, label)
    scaler.scale(loss).backward()

    scaler.unscale_(optimizer)
    if not are_grads_finite(model.parameters()):
        optimizer.zero_grad()
        scaler.update()
        continue
    optimizer.second_step(zero_grad=True)
    scaler.update()
Basically, you need to do manually what scaler.step(optimizer) does, since you can’t tell it to run optimizer.first_step() or second_step(). I tested this while training some classification models, and they converged without issue.
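If you want to avoid duplicating the unscale/check/step/update logic for both passes, one option (an untested sketch under the same assumptions as the loop above, reusing the are_grads_finite helper) is to factor it into a small function:

def scaled_sam_step(scaler, optimizer, params, step_fn):
    # Roughly mirrors what scaler.step(optimizer) does: unscale, skip the step on
    # inf/nan grads, otherwise take the step, then update the scale.
    scaler.unscale_(optimizer)
    if not are_grads_finite(params):
        optimizer.zero_grad()
        scaler.update()  # shrinks the scale, like a skipped scaler.step()
        return False
    step_fn(zero_grad=True)
    scaler.update()
    return True

# Usage inside the training loop, after scaler.scale(loss).backward():
#   if not scaled_sam_step(scaler, optimizer, list(model.parameters()), optimizer.first_step):
#       continue  # skip the rest of the batch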