You can do this:
import torch

@torch.no_grad()
def step_with_update_clip_(optimizer, value=1.0):
    # extract all parameters that received gradients from the optimizer
    params: list[torch.Tensor] = [p for g in optimizer.param_groups for p in g['params'] if p.grad is not None]
    params_before = [p.clone() for p in params]  # store parameters before the step
    optimizer.step()  # update the parameters
    # the update is the difference between parameters after and before the step; clip it element-wise
    clipped_update = [torch.clip(p - p_before, -value, value) for p, p_before in zip(params, params_before)]
    # revert the parameters and apply the clipped update instead
    for p, p_before, u in zip(params, params_before, clipped_update):
        p.set_(p_before + u)
Replace optimizer.step() with step_with_update_clip_(optimizer).
Example of how to use it:
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), 1e-2)
criterion = torch.nn.MSELoss()

inputs = torch.randn(10, 2)
targets = torch.randn(10, 2)

for _ in range(1000):
    preds = model(inputs)
    loss = criterion(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    step_with_update_clip_(optimizer, value=1.0)

print(loss)
This does require extra clone and subtract operations, but they add very little overhead compared to the total number of operations performed by the update rule itself.
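If you want to check the overhead on your own model, a minimal timing sketch along these lines works (it assumes the step_with_update_clip_ function above is in scope; the layer size and iteration count are arbitrary, and the exact numbers will vary with model size and hardware):

import time
import torch

model = torch.nn.Linear(512, 512)
optimizer = torch.optim.Adam(model.parameters(), 1e-2)

# create some gradients to step on
loss = model(torch.randn(64, 512)).pow(2).mean()
loss.backward()

# warm-up, then time the plain step
optimizer.step()
t0 = time.perf_counter()
for _ in range(100):
    optimizer.step()
plain = time.perf_counter() - t0

# time the step with update clipping (extra clone, subtract, clip, add)
t0 = time.perf_counter()
for _ in range(100):
    step_with_update_clip_(optimizer, value=1.0)
clipped = time.perf_counter() - t0

print(f"plain step: {plain:.4f}s, clipped step: {clipped:.4f}s")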
Also, I am developing a library, torchzero (GitHub: inikishev/torchzero, a modular optimization library for PyTorch), which does exactly that: you can chain various gradient transformations. For example, a typical AdamW with gradient clipping is
[ClipValue(2), Adam(), WeightDecay(1e-2), LR(1e-3)]
but you could apply clipping to Adam’s update by moving ClipValue after Adam:
[Adam(), ClipValue(2), WeightDecay(1e-2), LR(1e-3)]
And there is no redundant cloning and subtracting.
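Purely to illustrate the ordering idea (this is not torchzero's actual API, just a toy sketch where each "module" is a plain function and adamish stands in for a normalizing update rule like Adam's):

import torch

# apply a chain of tensor-to-tensor transforms left to right
def chain(*modules):
    def run(t):
        for m in modules:
            t = m(t)
        return t
    return run

clip_value = lambda v: (lambda t: t.clamp(-v, v))
adamish = lambda t: t / (t.abs().mean() + 1e-8)  # stand-in for an Adam-like rescaling update
lr = lambda a: (lambda t: a * t)

g = torch.tensor([0.1, 5.0, -7.0])
print(chain(clip_value(2), adamish, lr(1e-3))(g))  # clipping applied to the gradient
print(chain(adamish, clip_value(2), lr(1e-3))(g))  # clipping applied to the Adam-ish update

The two prints differ, which is exactly the difference between clipping the gradient and clipping the update.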
As of now the library is not fully finished and basically undocumented, so although it is fully usable, I don't "advertise" it too much. But I am actually quite close to finishing it.