Why is closure not supported in GradScaler?

In GradScaler's step function, if a closure is passed as a keyword argument (in kwargs), a RuntimeError is raised by the following code.

The error message says “not currently supported”. Are there any plans to support this feature? If not, could you tell me why closures aren’t supported? If possible, I would like to try writing a patch to fix this.

My motivation is to make it easy to use SAM (Sharpness-Aware Minimization) with native AMP.

Making closures work with dynamic gradient scaling (specifically, the fact that dynamic gradient scaling occasionally skips optimizer.step() if any grads were inf/nan) is tricky, and we haven’t heard of any use cases that absolutely needed it (LBFGS is the only one I’m aware of, and no one has asked for that).
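To see where a closure gets in the way, here is a heavily simplified, pure-Python toy of dynamic loss scaling. `ToyDynamicScaler`, its parameters, and `sgd_update` are hypothetical, and this is not GradScaler's actual code (the real one splits the inf/nan check, the step, and the scale update across `unscale_()`, `step()`, and `update()`). The toy unscales the gradients, skips the update when any of them is inf/nan (backing off the scale), and otherwise applies the update. The skip decision relies on gradients that already exist before the update runs; a closure passed to optimizer.step() would instead recompute the loss and gradients inside the step, after this check, so the scaler would never get to inspect them.

```python
import math


class ToyDynamicScaler:
    """Hypothetical, heavily simplified dynamic loss scaler (NOT GradScaler's real code)."""

    def __init__(self, init_scale=2.0**16, backoff_factor=0.5, growth_factor=2.0, growth_interval=2000):
        self.scale = init_scale
        self.backoff_factor = backoff_factor
        self.growth_factor = growth_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, scaled_grads, apply_update):
        """Unscale grads, then run apply_update(grads) only if all are finite.

        Returns True if the update ran, False if it was skipped.
        """
        grads = [g / self.scale for g in scaled_grads]
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            # Overflow: skip this update entirely and back off the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        # The skip decision above used gradients computed BEFORE the update.
        # A closure would compute fresh gradients inside the update instead,
        # and this check would never see them.
        apply_update(grads)
        self._good_steps += 1
        if self._good_steps % self.growth_interval == 0:
            self.scale *= self.growth_factor
        return True


# Usage: plain SGD on two scalar weights.
weights = [1.0, -2.0]
lr = 0.1


def sgd_update(grads):
    for i, g in enumerate(grads):
        weights[i] -= lr * g


scaler = ToyDynamicScaler(init_scale=4.0)
ran = scaler.step([4.0, 8.0], sgd_update)            # unscaled grads [1.0, 2.0]; update runs
ran2 = scaler.step([float("inf"), 8.0], sgd_update)  # inf grad; update skipped, scale halved
```

After the two calls the weights reflect only the first update, and the scale has dropped from 4.0 to 2.0.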

Can you implement SAM without a closure like this?
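For reference, SAM without a closure just means the caller runs both forward-backward passes explicitly and hands the gradients to two separate steps. Below is a minimal pure-Python sketch on a toy quadratic loss, mirroring the first_step()/second_step() split used in davda54's implementation discussed in this thread. `ToySAM`, `grad_fn`, the toy loss, and the choice of plain SGD as the base optimizer are all hypothetical; this is not davda54's code and it does not use PyTorch or GradScaler.

```python
import math


def grad_fn(w, c):
    # Gradient of the toy loss f(w) = 0.5 * sum(c_i * w_i**2); stands in for one forward-backward pass.
    return [ci * wi for ci, wi in zip(c, w)]


class ToySAM:
    """Hypothetical pure-Python SAM with plain SGD as the base optimizer."""

    def __init__(self, params, lr=0.1, rho=0.05):
        self.params = list(params)
        self.lr = lr
        self.rho = rho
        self._saved = None

    def first_step(self, grads):
        # Move to the approximate worst-case neighbor w + rho * g / ||g||.
        norm = math.sqrt(sum(g * g for g in grads)) + 1e-12
        self._saved = list(self.params)
        self.params = [w + self.rho * g / norm for w, g in zip(self.params, grads)]

    def second_step(self, grads):
        # Restore the original w, then take an SGD step using the gradient
        # that was computed at the perturbed point.
        self.params = [w - self.lr * g for w, g in zip(self._saved, grads)]
        self._saved = None


c = [1.0, 10.0]
opt = ToySAM([1.0, 1.0], lr=0.1, rho=0.05)
opt.first_step(grad_fn(opt.params, c))   # first pass: gradient at w
opt.second_step(grad_fn(opt.params, c))  # second pass: gradient at w + e(w)
```

Because both gradient computations happen in the caller's code, no closure is needed; the cost is that the training loop has to orchestrate two passes per step.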

Thank you for your reply!
As you pointed out, I use davda54’s SAM implementation without a closure, and it works well.

However, I use PyTorch Lightning, and there I have to write some extra code to compute the gradients twice. That is the problem I want to solve: I would like to use SAM just like any other optimizer.

I understand that closure support is tricky, so I’ve decided to use SAM without a closure. Thanks!!

Hi, I have the same problem you had. How did you get fp16 working without a closure using the linked GitHub repo? As far as I can tell, we have to call scaler.step(optimizer), which in turn calls optimizer.step() with no closure argument. But to use the library we need to call first_step() and second_step(), so I don’t see how to combine this implementation with float16 support.

Hi, I don’t remember clearly, but I’m sure I used the PyTorch Lightning example code from the repo’s README.md:

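```python
def training_step(self, batch, batch_idx):
    # Manual optimization: requires self.automatic_optimization = False in __init__
    optimizer = self.optimizers()

    # first forward-backward pass: gradient at the current weights w
    loss_1 = self.compute_loss(batch)
    self.manual_backward(loss_1, optimizer)
    optimizer.first_step(zero_grad=True)   # perturb weights to w + e(w)

    # second forward-backward pass: gradient at the perturbed weights
    loss_2 = self.compute_loss(batch)
    self.manual_backward(loss_2, optimizer)
    optimizer.second_step(zero_grad=True)  # restore w and apply the base optimizer update

    return loss_1
```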