How to clip the values of an optimizer?

Hey,

To prevent NaN values, a common strategy is gradient clipping, which caps the gradients at a chosen threshold.

Before the gradient is applied, however, it goes through an optimizer that uses, e.g., momentum. How can I (efficiently) clip these internal optimizer values as well as the final update applied to the weights?

Is there any reason not to clip the optimizer's internal state or its final update?
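For reference, the standard gradient clipping mentioned above is usually done with `torch.nn.utils.clip_grad_value_` between `loss.backward()` and `optimizer.step()`. The minimal sketch below (the model, data, and SGD-with-momentum settings are arbitrary choices for illustration) shows where it sits: it only touches the raw gradients, so the momentum buffer and the final update computed inside `optimizer.step()` are not clipped.

```python
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
criterion = torch.nn.MSELoss()

inputs, targets = torch.randn(10, 2), torch.randn(10, 2)

loss = criterion(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
# clip every gradient element to [-1, 1] in place; this acts on the raw
# gradients only, not on the momentum buffer or on the update that
# optimizer.step() derives from them
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
optimizer.step()
```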

I haven’t heard of approaches that clip the internal state of an optimizer, but you could iterate over its state and apply clipping to the desired internal buffers.
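A minimal sketch of iterating over the optimizer state, assuming the buffer names of the built-in optimizers (`torch.optim.Adam` stores `exp_avg`/`exp_avg_sq`, SGD with momentum stores `momentum_buffer`). The helper name `clip_optimizer_state_` and its `keys` parameter are hypothetical, chosen for this example. Note that the state only exists after the first `optimizer.step()`, and clipped buffers only affect the next step.

```python
import torch

@torch.no_grad()
def clip_optimizer_state_(optimizer, value=1.0, keys=("exp_avg", "momentum_buffer")):
    # Hypothetical helper. Assumption: the names in `keys` match the optimizer's
    # internal buffers. Missing keys are skipped, and 'step' is never touched.
    for param_state in optimizer.state.values():
        for key in keys:
            buf = param_state.get(key)
            if torch.is_tensor(buf):
                buf.clamp_(-value, value)  # in place, so the optimizer sees the clipped buffer

# usage: the state is created by the first step; the clipped buffers
# take effect on the *next* optimizer.step()
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss = model(torch.randn(10, 2)).pow(2).mean() * 1000  # deliberately large gradients
optimizer.zero_grad()
loss.backward()
optimizer.step()
clip_optimizer_state_(optimizer, value=0.01)
```

Clipping `exp_avg_sq` is left out of the defaults on purpose: it is the denominator of Adam's update, so shrinking it would enlarge the steps rather than limit them.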

To clip the final update applied to the weights, you can do this:

```python
import torch

@torch.no_grad()
def step_with_update_clip_(optimizer, value=1.0):

    # collect the parameters that will actually be updated (those with a gradient)
    params: list[torch.Tensor] = [
        p for g in optimizer.param_groups for p in g['params'] if p.grad is not None
    ]
    params_before = [p.clone() for p in params]  # store parameters before the step
    optimizer.step()  # update parameters

    # the update is the difference between parameters after and before the step; clip it
    clipped_update = [torch.clip(p - p_before, -value, value) for p, p_before in zip(params, params_before)]

    # revert the parameters and apply the clipped update instead
    # (copy_ writes into the existing storage rather than swapping it like set_)
    for p, p_before, u in zip(params, params_before, clipped_update):
        p.copy_(p_before + u)
```

Replace optimizer.step() with step_with_update_clip_(optimizer).

Example of how to use it:

```python
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = torch.nn.MSELoss()

inputs = torch.randn(10, 2)
targets = torch.randn(10, 2)

for _ in range(1000):
    preds = model(inputs)
    loss = criterion(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    step_with_update_clip_(optimizer, value=1.0)
    print(loss.item())
```

This does require an extra clone and subtraction per parameter (and temporarily holds a copy of all parameters in memory), but the compute overhead is small compared to the total number of operations performed by the update rule.

I am also developing a library that does exactly that, torchzero (GitHub - inikishev/torchzero: Modular optimization library for PyTorch). It lets you chain various gradient transformations; for example, a typical AdamW with gradient clipping is

[ClipValue(2), Adam(), WeightDecay(1e-2), LR(1e-3)]

but you can instead clip Adam’s update by moving ClipValue after Adam:

[Adam(), ClipValue(2), WeightDecay(1e-2), LR(1e-3)]

This way there is no redundant cloning and subtraction.
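To illustrate why the position of the clipping step matters, here is a minimal plain-PyTorch sketch of the chaining idea. The classes below are hypothetical stand-ins written for this example; they are not torchzero's actual API or implementation. Because Adam normalizes the update by the running RMS of the gradient (its first step is roughly `sign(grad)` whatever the gradient's scale), clipping the gradient before Adam does not directly bound the step, while clipping after Adam does.

```python
import torch

# Hypothetical transforms for illustration only (NOT torchzero's API).
# Each maps (update, param) -> new update; a chain applies them in order.

class ClipValue:
    def __init__(self, value):
        self.value = value

    def __call__(self, update, param):
        return update.clamp(-self.value, self.value)


class Adam:
    def __init__(self, beta1=0.9, beta2=0.999, eps=1e-8):
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.state = {}  # per-parameter moment buffers, keyed by id(param)

    def __call__(self, update, param):
        st = self.state.setdefault(
            id(param), {"m": torch.zeros_like(param), "v": torch.zeros_like(param), "t": 0}
        )
        st["t"] += 1
        st["m"].mul_(self.beta1).add_(update, alpha=1 - self.beta1)
        st["v"].mul_(self.beta2).addcmul_(update, update, value=1 - self.beta2)
        m_hat = st["m"] / (1 - self.beta1 ** st["t"])  # bias correction
        v_hat = st["v"] / (1 - self.beta2 ** st["t"])
        return m_hat / (v_hat.sqrt() + self.eps)


class WeightDecay:  # decoupled (AdamW-style) when placed after Adam
    def __init__(self, weight_decay):
        self.weight_decay = weight_decay

    def __call__(self, update, param):
        return update + self.weight_decay * param


class LR:
    def __init__(self, lr):
        self.lr = lr

    def __call__(self, update, param):
        return update * self.lr


@torch.no_grad()
def chained_step_(params, chain):
    for p in params:
        if p.grad is None:
            continue
        update = p.grad
        for transform in chain:
            update = transform(update, p)
        p.sub_(update)  # gradient descent: subtract the final update


# Compare clipping before vs. after Adam on one step with large gradients.
grad = torch.tensor([100.0, -100.0, 0.001])

p_clip_first = torch.zeros(3, requires_grad=True)
p_clip_first.grad = grad.clone()
chained_step_([p_clip_first], [ClipValue(0.5), Adam(), LR(1.0)])

p_clip_after = torch.zeros(3, requires_grad=True)
p_clip_after.grad = grad.clone()
chained_step_([p_clip_after], [Adam(), ClipValue(0.5), LR(1.0)])

print(p_clip_first.detach().tolist())  # roughly [-1, 1, -1]: clipping the gradient did not bound the step
print(p_clip_after.detach().tolist())  # [-0.5, 0.5, -0.5]: the step itself was clipped
```

With these stand-ins, the AdamW-style chain from the post would read `[Adam(), ClipValue(2), WeightDecay(1e-2), LR(1e-3)]`, and each transform touches every parameter once, with no cloning of the parameters.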

As of now the library is not finished and is largely undocumented, so although it is fully usable, I don’t “advertise” it much. That said, I am quite close to finishing it.

This would also answer another question I had, about how to set it. Thank you for that!