I have been working on a PyTorch implementation of this paper and am implementing the Straight-through estimator algorithm on page 26.

Currently, I have the following code to calculate the midpoint parameters, store them in `pi_model_state_dict`, and load that into a blank model called `pi_model` via `load_state_dict`:

```python
for p_m, (n, p_a), p_p in zip(
    params_model.values(), params_a.items(), proj_params.values()
):
    print(n, type(p_m), type(p_a), type(p_p))  # parameter, parameter, tensor
    pi_model_state_dict[n] = 0.5 * (p_a + ((p_p - p_m).detach() + p_m))
    print(n, type(pi_model_state_dict[n]))  # tensor!

pi_model.load_state_dict(pi_model_state_dict)

optimizer.zero_grad()  # optimizer = t.optim.Adam(model.parameters(), lr=args.lr)
output = pi_model(data)
loss = loss_fn(output, target)
loss.backward()

for p in model.parameters():
    print(f"grad time: {p.grad}")  # all None!
```
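For reference, the straight-through trick itself does propagate gradients when the surrogate tensor is used directly in the forward computation. Here's a minimal standalone sketch of what I'm trying to reproduce (using `torch.round` as a stand-in for my non-differentiable projection):

```python
import torch

# leaf parameter we want gradients on
p_m = torch.randn(3, requires_grad=True)
# stand-in for the non-differentiable operation (my real op is a projection)
p_p = torch.round(p_m)

# straight-through surrogate: forward value equals p_p,
# but the gradient flows straight through to p_m
ste = (p_p - p_m).detach() + p_m

loss = (ste ** 2).sum()
loss.backward()
print(p_m.grad)  # not None — the gradient reaches p_m through the surrogate
```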

Here, `params_model = {k: v for k, v in model.named_parameters()}`, where `model` is a different module with the same architecture as `pi_model` but with its weights loaded from a checkpoint, and `proj_params` is the result of a non-differentiable operation performed on the parameters in `model.parameters()`.
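My suspicion is that the graph is being broken at `load_state_dict`, since (as far as I understand) it copies the incoming values into the module's existing leaf parameters under `no_grad`, discarding any autograd history the state-dict tensors carried. A small check seems consistent with that (the `Linear` modules here are just a toy stand-in for my models):

```python
import torch

src = torch.nn.Linear(2, 2)
dst = torch.nn.Linear(2, 2)

# build a state dict whose values are non-leaf tensors with autograd history
sd = {k: v * 1.0 for k, v in src.named_parameters()}
print(sd["weight"].grad_fn)  # has a grad_fn before loading

dst.load_state_dict(sd)
# after loading, dst's parameters are still leaf tensors with no history,
# so a backward() through dst can never reach src's parameters
print(dst.weight.grad_fn)  # None
```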

Ideally, I want the computational graph set up so that I can compute the gradient of the loss through `pi_model` and see the gradients appear on `model.parameters()`, but this doesn't seem to be working: the `.grad` attribute on each parameter in `model.parameters()` is `None`. Is there another way to implement a straight-through estimator that would make this work? Thank you so much!