Hi,
Instead of doing the classical:
pred = model(inp)
loss = criterion(pred, ground_truth)
optimizer.zero_grad()
loss.backward()
optimizer.step()
You can compute the updates by hand and then set them into the weights.
This is basically what SGD does, where update_function returns p - lr * p.grad.
pred = model(inp)
loss = your_loss(pred)
model.zero_grad()
loss.backward()
with torch.no_grad():
    for p in model.parameters():
        new_val = update_function(p, p.grad, loss, other_params)
        p.copy_(new_val)
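Here is a minimal runnable sketch of the pattern above, using a tiny linear model and plain SGD as the update rule. The names update_function, inp, and ground_truth are just illustrative placeholders, and the signature is simplified (no other_params):

```python
import torch
import torch.nn as nn

def update_function(p, grad, loss, lr=0.1):
    # Plain SGD as an example rule: step against the gradient.
    # Any custom rule returning a tensor of p's shape works here.
    return p - lr * grad

model = nn.Linear(3, 1)
criterion = nn.MSELoss()

inp = torch.randn(8, 3)
ground_truth = torch.randn(8, 1)

pred = model(inp)
loss = criterion(pred, ground_truth)
model.zero_grad()
loss.backward()

# Disable autograd while writing the new values into the parameters,
# so the update itself is not recorded in the graph.
with torch.no_grad():
    for p in model.parameters():
        new_val = update_function(p, p.grad, loss)
        p.copy_(new_val)  # in-place write keeps the same Parameter object
```

Using p.copy_(new_val) (rather than reassigning p) updates the tensor in place, so the parameters stay registered with the model and keep requires_grad=True for the next iteration.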