Straight Through Estimator on Model Parameters

I have been working on a PyTorch implementation of this paper and am implementing the Straight-through estimator algorithm on page 26.

Currently, I have the following code, which computes the midpoint parameters, stores them in pi_model_state_dict, and then loads them into a blank model called pi_model via load_state_dict:

for (
    p_m,
    (n, p_a),
    p_p,
) in zip(params_model.values(), params_a.items(), proj_params.values()):
    print(
        n, type(p_m), type(p_a), type(p_p)
    )  # parameter, parameter, tensor
    pi_model_state_dict[n] = 0.5 * (p_a + ((p_p - p_m).detach() + p_m))
    print(n, type(pi_model_state_dict[n]))  # tensor!

optimizer.zero_grad() # optimizer = t.optim.Adam(model.parameters(),
output = pi_model(data)
loss = loss_fn(output, target)
loss.backward()
for p in model.parameters():
    print(f"grad time: {p.grad}") # All None!

Here, params_model = {k: v for k, v in model.named_parameters()}, where model is a different module with the same architecture as pi_model but with its weights loaded from a checkpoint. proj_params is the result of a non-differentiable operation applied to the parameters in model.parameters().
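For reference, here is a minimal sketch of what the expression (p_p - p_m).detach() + p_m does on a single tensor. The values are made up, and torch.round is only a stand-in for the non-differentiable projection, which is an assumption and not the operation from the paper.

```python
import torch

# Hypothetical stand-ins: p_m plays the role of one model parameter,
# torch.round stands in for the non-differentiable projection.
p_m = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
p_p = torch.round(p_m.detach())  # projected values, no graph attached

# Forward value equals p_p; in the backward pass only the trailing
# "+ p_m" term carries gradient, so the projection acts like identity.
p_ste = (p_p - p_m).detach() + p_m

loss = (p_ste ** 2).sum()
loss.backward()

ste_value = p_ste.detach().clone()
ste_grad = p_m.grad.clone()  # gradient of the loss evaluated at p_p
```

The forward pass sees the projected values, while the gradient reaches p_m as if the projection were not there.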

Ideally, I want the computational graph to connect the loss computed on pi_model back to model.parameters(), so that calling backward on that loss populates their gradients. This doesn't seem to be working, since the .grad attribute of every parameter in model.parameters() is None. Is there another way to implement a straight-through estimator that would make this work? Thank you so much!
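A likely explanation for the None gradients is that load_state_dict copies values into pi_model's own parameters, so the loss ends up connected to those parameters and not to model. The sketch below reproduces this with two small nn.Linear modules. The modules and the 0.5 scaling are placeholders for illustration, not the setup from the post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: `model` holds the trained weights,
# `pi_model` is the blank module with the same architecture.
model = nn.Linear(3, 2)
pi_model = nn.Linear(3, 2)

# Tensors that are still connected to model's parameters
# (the scaling is a placeholder for the midpoint computation).
pi_model_state_dict = {n: 0.5 * p for n, p in model.named_parameters()}

# load_state_dict copies the values into pi_model's own parameters,
# so the link back to `model` is not kept.
pi_model.load_state_dict(pi_model_state_dict)

loss = pi_model(torch.randn(4, 3)).sum()
loss.backward()

model_grads = [p.grad for p in model.parameters()]        # stays None
pi_model_grads = [p.grad for p in pi_model.parameters()]  # populated instead
```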

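Solved the issue. Using torch.func.functional_call to run pi_model directly on the computed tensors, without going through load_state_dict, keeps the graph back to model intact. The following worked: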

for (
    p_m,
    (n, p_a),
    p_p,
) in zip(
    params_model.values(), params_a.items(), proj_params.values()
):
    pi_model_state_dict[n] = 0.5 * (p_a + ((p_p - p_m).detach() + p_m))

output = t.func.functional_call(pi_model, pi_model_state_dict, data)
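For anyone who wants to try this end to end, here is a self-contained sketch of the same pattern. The nn.Linear modules, the random data, the MSE loss, and torch.round as the projection are all assumptions made for illustration. Only the midpoint expression and the functional_call usage come from the code above.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the objects in the post.
model = nn.Linear(3, 2)     # parameters that should receive gradients
model_a = nn.Linear(3, 2)   # source of params_a
pi_model = nn.Linear(3, 2)  # blank module with the same architecture

params_model = dict(model.named_parameters())
params_a = {k: v.detach() for k, v in model_a.named_parameters()}
# torch.round stands in for the non-differentiable projection.
proj_params = {k: torch.round(v.detach()) for k, v in params_model.items()}

# Midpoint with the straight-through term, as in the post.
pi_model_state_dict = {}
for n, p_m in params_model.items():
    p_a, p_p = params_a[n], proj_params[n]
    pi_model_state_dict[n] = 0.5 * (p_a + ((p_p - p_m).detach() + p_m))

data = torch.randn(4, 3)
target = torch.randn(4, 2)

# Run pi_model's forward pass with the computed tensors in place of
# its own parameters, so the graph still leads back to `model`.
output = functional_call(pi_model, pi_model_state_dict, (data,))
loss = nn.functional.mse_loss(output, target)
loss.backward()

grads_on_model = {n: p.grad for n, p in model.named_parameters()}
grads_on_pi_model = {n: p.grad for n, p in pi_model.named_parameters()}
```

In this sketch the gradients land on model's parameters, while pi_model's own parameters are never used in the forward pass and keep .grad as None.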