Updating Parameters without using optimizer.step()

Hi,

Instead of doing the classical:

pred = model(inp)
loss = criterion(pred, ground_truth)
optimizer.zero_grad()
loss.backward()
optimizer.step()

You can compute the updates by hand and then write them into the weights.
This is basically what SGD does, where the update_function returns p - lr * p.grad.

pred = model(inp)
loss = your_loss(pred)
model.zero_grad()         # clear any stale gradients
loss.backward()           # populate p.grad for every parameter
with torch.no_grad():     # don't record the manual update in autograd
    for p in model.parameters():
        new_val = update_function(p, p.grad, loss, other_params)
        p.copy_(new_val)  # write the new value in place
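
For concreteness, here is a minimal sketch of an update_function that reproduces vanilla SGD. The learning rate lr and its default of 0.1 are my own assumptions, and loss and other_params are unused here, kept only to match the call signature above:

def update_function(p, grad, loss, other_params, lr=0.1):
    # Vanilla SGD: step against the gradient to decrease the loss.
    # loss and other_params are ignored by this simple update rule.
    return p - lr * grad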