You could try to load the "backward" state_dict before executing the backward operation, but it's quite a hacky approach:
import copy

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 1, bias=False),
    nn.Linear(1, 1, bias=False)
)
sd_forward = copy.deepcopy(model.state_dict())
sd_backward = copy.deepcopy(sd_forward)

# one train step: forward pass with the "forward" weights
output = model(torch.ones(1, 1))
# load the "backward" state_dict before calling backward
model.load_state_dict(sd_backward)
output.backward()
Also note that the last gradient is wrong, since the output was calculated using the old weights.
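To see why, here is a minimal standalone sketch (not from the snippet above; the scalar weights `w1_old`, `w1_new`, and `w2` are made up for illustration). For `y = w2 * (w1 * x)`, the gradient w.r.t. `w2` equals the layer's input `w1 * x`, so it depends on which `w1` produced the forward activation:

```python
import torch

x = torch.ones(1, 1)
w1_old = torch.tensor(2.0)                      # weight used in the forward pass
w1_new = torch.tensor(5.0)                      # weight you would swap in later
w2 = torch.tensor(3.0, requires_grad=True)

# d y / d w2 equals the layer's input, i.e. w1 * x
grad_old = torch.autograd.grad((w2 * (w1_old * x)).sum(), w2)[0]
grad_new = torch.autograd.grad((w2 * (w1_new * x)).sum(), w2)[0]

# grad_old != grad_new: the last layer's gradient is tied to the
# activations computed with the old weights
print(grad_old.item(), grad_new.item())
```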
Would this approach work for you or did I misunderstand your question?