How to write a closure() that re-evaluates the model with new weights?
Any comments on my attempt below? (Modified from here.)
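```python
for input, target in dataset:
    def closure(new_weight_idx=None, new_weight=None):
        state_dict = model.state_dict()
        key, old_weight = None, None
        # update the model with the new_weight
        # (note: state_dict() keys include buffers, not only parameters)
        if (new_weight_idx is not None) and (new_weight is not None):
            key = list(state_dict.keys())[new_weight_idx]
            # state_dict() tensors share storage with the model, so clone();
            # otherwise load_state_dict below would overwrite old_weight too
            old_weight = state_dict[key].clone()
            state_dict[key] = new_weight
            model.load_state_dict(state_dict)
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        # put the old_weight back (.grad still holds the gradient at new_weight)
        if key is not None:
            state_dict[key] = old_weight
            model.load_state_dict(state_dict)
        return loss
    optimizer.step(closure)
```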
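Note that `optimizer.step(closure)` calls the closure with no arguments, so the new-weight branch only runs when your own code calls `closure(idx, w)` directly. One way to avoid the save/restore bookkeeping entirely is `torch.func.functional_call` (PyTorch 2.0+), which runs the module with a substitute dict of parameters so trial weights never touch the model's own tensors. The sketch below assumes that API is available; `model`, `loss_fn`, the data, and the helper `loss_at` are hypothetical stand-ins, and the trial tensor is selected by name rather than by index.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(3, 1)      # hypothetical stand-in for the real model
loss_fn = nn.MSELoss()
inputs = torch.randn(8, 3)   # hypothetical batch
target = torch.randn(8, 1)


def loss_at(params):
    """Loss and gradients at `params` (a name -> tensor dict), evaluated
    without writing anything into the model's own parameters."""
    params = {k: v.detach().requires_grad_(True) for k, v in params.items()}
    loss = loss_fn(functional_call(model, params, (inputs,)), target)
    grads = torch.autograd.grad(loss, list(params.values()))
    return loss.detach(), dict(zip(params.keys(), grads))


# Current weights, cloned so they cannot alias the model's storage
base = {k: v.detach().clone() for k, v in model.named_parameters()}

# Trial point: identical to base except for one tensor
trial = dict(base)
trial["weight"] = base["weight"] + 0.1

loss0, grads0 = loss_at(base)
loss1, grads1 = loss_at(trial)
```

Because nothing is mutated, there is no "put the old weight back" step. `named_parameters()`, unlike `state_dict()`, lists parameters only, so names taken from it never point at a buffer.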
Use case: I am trying to implement Algorithm 7.1 (Line Search Newton–CG) from Nocedal and Wright's Numerical Optimization.
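Algorithm 7.1 never needs the Hessian itself, only Hessian-vector products, which autograd can supply by differentiating the gradient a second time. Below is a minimal sketch for a function `f` of a flat 1-D tensor. The name `newton_cg`, its default parameters, the CG iteration cap, and the use of Armijo backtracking from a unit step for the line search are assumptions of this sketch. To apply it to a network, `f` would have to map a flat parameter vector to the loss (for example by unflattening into a dict for `torch.func.functional_call`); that wrapper is not shown.

```python
import torch


def newton_cg(f, x0, max_iter=100, gtol=1e-8, c1=1e-4, shrink=0.5, max_backtracks=50):
    """Sketch of Line Search Newton-CG (Nocedal & Wright, Algorithm 7.1).

    f maps a 1-D tensor to a scalar tensor. Hessian-vector products come
    from double backprop, so the Hessian is never formed explicitly.
    """
    x = x0.detach().clone()
    for _ in range(max_iter):
        xg = x.clone().requires_grad_(True)
        fx_t = f(xg)
        # create_graph=True keeps the graph of g so it can be differentiated again
        (g,) = torch.autograd.grad(fx_t, xg, create_graph=True)
        grad = g.detach()
        gnorm = grad.norm().item()
        if gnorm < gtol:
            break
        eps = min(0.5, gnorm ** 0.5) * gnorm  # forcing term eps_k

        def hess_vec(v):
            # H(x) @ v = d/dx (g . v); retain_graph allows repeated calls
            (hv,) = torch.autograd.grad(g, xg, grad_outputs=v, retain_graph=True)
            return hv

        # Inner CG loop on B p = -grad, stopped early on negative curvature
        z = torch.zeros_like(x)
        r = grad.clone()
        d = -r
        p = None
        for j in range(x.numel()):
            Bd = hess_vec(d)
            dBd = torch.dot(d, Bd)
            if dBd <= 0:
                p = -grad if j == 0 else z
                break
            alpha = torch.dot(r, r) / dBd
            z = z + alpha * d
            r_new = r + alpha * Bd
            if r_new.norm() < eps:
                p = z
                break
            beta = torch.dot(r_new, r_new) / torch.dot(r, r)
            d = -r_new + beta * d
            r = r_new
        if p is None:
            p = z  # CG iteration cap reached

        # Backtracking (Armijo) line search, trying a unit step first
        fx = fx_t.detach()
        slope = torch.dot(grad, p)
        step = 1.0
        with torch.no_grad():
            for _ in range(max_backtracks):
                if f(x + step * p) <= fx + c1 * step * slope:
                    break
                step *= shrink
            else:
                return x  # no acceptable step found; stop here
        x = x + step * p
    return x
```

As a sanity check, on a strictly convex quadratic f(x) = 0.5 xᵀAx − bᵀx with symmetric positive definite A, the result should agree with `torch.linalg.solve(A, b)` up to the gradient tolerance.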