I have been working on a PyTorch implementation of this paper and am implementing the straight-through estimator algorithm on page 26. Currently, I have the following code, which computes the midpoint parameters, stores them in `pi_model_state_dict`, and loads them into a blank model called `pi_model` via `load_state_dict`:
```python
for p_m, (n, p_a), p_p in zip(
    params_model.values(), params_a.items(), proj_params.values()
):
    print(n, type(p_m), type(p_a), type(p_p))  # parameter, parameter, tensor
    pi_model_state_dict[n] = 0.5 * (p_a + ((p_p - p_m).detach() + p_m))
    print(n, type(pi_model_state_dict[n]))  # tensor!

pi_model.load_state_dict(pi_model_state_dict)
optimizer.zero_grad()  # optimizer = t.optim.Adam(model.parameters(), lr=args.lr)
output = pi_model(data)
loss = loss_fn(output, target)
loss.backward()

for p in model.parameters():
    print(f"grad time: {p.grad}")  # All None!

params_model = {k: v for k, v in model.named_parameters()}
```
where `model` is a different module with the same architecture as `pi_model`, but with its weights loaded from a checkpoint, and `proj_params` is the result of a non-differentiable operation applied to the parameters in `model.parameters()`.
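For reference, the detach trick itself does route gradients the way I expect when applied to bare tensors. Here is a minimal sketch (toy values, all names are mine, not from the paper) of the pattern in isolation:

```python
import torch

# Stand-ins: p_m plays the role of a model parameter, p_p the result of a
# non-differentiable projection, p_a a checkpointed parameter.
p_m = torch.tensor([1.0, 2.0], requires_grad=True)
p_p = torch.tensor([3.0, 4.0])
p_a = torch.tensor([5.0, 6.0])

# Straight-through estimator: the forward value uses p_p, but the backward
# pass routes gradients to p_m because only (p_p - p_m) is detached.
midpoint = 0.5 * (p_a + ((p_p - p_m).detach() + p_m))

loss = midpoint.sum()
loss.backward()
print(p_m.grad)  # tensor([0.5000, 0.5000]) -- gradient reaches p_m
```

So the straight-through expression itself seems fine; the problem appears only once `load_state_dict` enters the picture.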
Ideally, I want the computational graph set up so that I can backpropagate the loss computed on `pi_model` and see gradients appear on `model.parameters()`. That doesn't seem to be happening: after `loss.backward()`, the `.grad` attribute of every parameter in `model.parameters()` is `None`. Is there another way to set up the straight-through estimator so this works? Thank you so much!
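One direction I was considering (I am not sure it is the right fix): my understanding is that `load_state_dict` copies tensor *values* into `pi_model`'s parameters without recording them in the autograd graph, which would sever the connection back to `model`. `torch.func.functional_call` instead runs a module with externally supplied parameter tensors, so the graph should survive. A self-contained sketch with a toy linear model (shapes, loss, and the `torch.round` projection are all my own stand-ins):

```python
import torch
from torch.func import functional_call

# Toy stand-ins for my real modules: same architecture, different weights.
model = torch.nn.Linear(4, 2)     # weights would come from a checkpoint
pi_model = torch.nn.Linear(4, 2)  # blank model to evaluate the midpoint in

params_model = dict(model.named_parameters())
# Stand-ins for the non-differentiable projection and the checkpoint params.
proj_params = {n: torch.round(p.detach()) for n, p in params_model.items()}
params_a = {n: p.detach().clone() for n, p in params_model.items()}

# Build the midpoint tensors WITHOUT load_state_dict, so each entry keeps
# a grad_fn pointing back at p_m.
pi_params = {
    n: 0.5 * (params_a[n] + ((proj_params[n] - p_m).detach() + p_m))
    for n, p_m in params_model.items()
}

data = torch.randn(8, 4)
target = torch.randn(8, 2)

# Run pi_model's forward pass with the midpoint tensors substituted in.
output = functional_call(pi_model, pi_params, (data,))
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()

print(all(p.grad is not None for p in model.parameters()))  # True
```

Would this be the recommended approach, or is there a cleaner way to keep `load_state_dict` in the loop?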