You could try to load the backward `state_dict` before executing the `backward` operation, but it's quite hacky:

```
import copy
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 1, bias=False),
    nn.Linear(1, 1, bias=False)
)
with torch.no_grad():
    model[0].weight.fill_(1.)
    model[1].weight.fill_(1.)

sd_forward = copy.deepcopy(model.state_dict())
sd_backward = copy.deepcopy(sd_forward)
sd_backward['0.weight'].fill_(10.)
sd_backward['1.weight'].fill_(10.)

# one train step: forward with the original weights,
# swap in the backward weights, then backpropagate
output = model(torch.ones(1, 1))
model.load_state_dict(sd_backward)
output.mean().backward()

print(model[0].weight.grad)
> tensor([[10.]])
print(model[1].weight.grad)
> tensor([[1.]])
```

Also note that the last gradient is wrong: the gradient for `model[1].weight` equals the intermediate activation saved during the forward pass, which was computed with the old weights.
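To see why, you can redo the chain rule by hand (a minimal sketch with plain scalars, not the module itself): the gradient for the second layer's weight is the saved input activation from the forward pass, while the gradient for the first layer's weight flows through the second layer's weight as it exists at backward time.

```python
import torch

x = torch.ones(1, 1)
w0_old, w1_old = 1., 1.    # weights used during the forward pass
w0_new, w1_new = 10., 10.  # weights loaded before backward

# forward pass saves the intermediate activation a = w0_old * x
a = w0_old * x

# backward after the swap:
# d(output)/d(w1) = a           -> saved activation, still the old value
# d(output)/d(w0) = w1_new * x  -> uses the weight loaded before backward
grad_w1 = a             # tensor([[1.]])
grad_w0 = w1_new * x    # tensor([[10.]])
```

This reproduces the `10.` / `1.` pair printed above: only the path through the swapped weight picks up the new value.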

Would this approach work for you or did I misunderstand your question?