Closure() that re-evaluates the model with new weights?

How can I make a closure() that re-evaluates the model with new weights?
Any comments on my attempt below? (modified from here)

for input, target in dataset:
    def closure(new_weight_idx=None, new_weight=None):
        old_weight = None
        # update the model with the new weight
        if (new_weight_idx is not None) and (new_weight is not None):
            state_dict = model.state_dict()
            key = list(state_dict.keys())[new_weight_idx]
            old_weight = state_dict[key]
            state_dict[key] = new_weight
            model.load_state_dict(state_dict)

        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()

        # put the old weight back
        if old_weight is not None:
            state_dict[key] = old_weight
            model.load_state_dict(state_dict)

        return loss

    optimizer.step(closure)

Use case: I'm trying to implement Algorithm 7.1 (Line Search Newton–CG) from Nocedal and Wright's Numerical Optimization book :slight_smile:
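For what it's worth, a common PyTorch pattern for a line search is to evaluate the loss at a trial point by perturbing the parameters in-place under torch.no_grad() and then restoring them from a backup. The sketch below is only an illustration of that re-evaluate-and-restore idea, not Algorithm 7.1 itself; the model, data, and candidate step sizes are made up for the example.

```python
import torch
import torch.nn as nn

# toy model and data, purely for illustration
model = nn.Linear(20, 2)
loss_fn = nn.MSELoss()
inp = torch.randn(8, 20)
target = torch.randn(8, 2)

def loss_at(direction, alpha):
    """Evaluate the loss at params + alpha * direction, then restore the params."""
    params = list(model.parameters())
    backup = [p.detach().clone() for p in params]  # keep exact copies
    with torch.no_grad():
        for p, d in zip(params, direction):
            p.add_(d, alpha=alpha)                 # move to the trial point
    loss = loss_fn(model(inp), target).item()
    with torch.no_grad():
        for p, b in zip(params, backup):
            p.copy_(b)                             # exact restore of old weights
    return loss

# e.g. probe a few step sizes along the negative gradient direction
loss_fn(model(inp), target).backward()
direction = [-p.grad.clone() for p in model.parameters()]
losses = [loss_at(direction, a) for a in (1.0, 0.5, 0.25)]
```

Restoring from saved clones (rather than subtracting the step back out) avoids floating-point drift, so the model is left bit-for-bit unchanged after each probe.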

I don’t really understand your question.
Could you explain your use case a bit more?
Is the closure working or do you have some issues with it?

@ptrblck I do apologize, I have updated my question.
Previously I accidentally hit the Tab key and the post was submitted right away, thank you.

Currently the function is passed into .step() and called without arguments.
See this line of code.

However, you could pass arguments using a lambda or, most likely, functools.partial.
Here is a small example:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(20, 2)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1.)

input = torch.randn(1, 20)
target = torch.empty(1, dtype=torch.long).random_(2)

optimizer.step(lambda: closure(0, torch.randn(size=model.weight.size())))
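To show the functools.partial variant end to end, here is a self-contained sketch where the closure takes the input and target as arguments; this closure is a hypothetical stand-in for the example (it just re-evaluates the loss, it is not the weight-swapping closure from the question).

```python
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(20, 2)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1.)

inp = torch.randn(1, 20)
target = torch.empty(1, dtype=torch.long).random_(2)

def closure(inp, target):
    # re-evaluate the model and return the loss, as .step() expects
    optimizer.zero_grad()
    loss = loss_fn(torch.log_softmax(model(inp), dim=1), target)
    loss.backward()
    return loss

# partial binds the arguments, so .step() can still call it with none
loss = optimizer.step(partial(closure, inp, target))
```

optimizer.step returns whatever the closure returns, so the loss at the old weights is available after the step.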

I cannot comment on the correctness of the approach, since the book is behind a paywall and I'm currently out of office, so I cannot check it.