Loading the state_dict manipulates the parameters in-place and causes the failure:
```python
# restores the parameters in-place between the forward and backward pass
model.load_state_dict({name: weights_original[name] for name in weights_original})
loss.backward()
```
Changing the parameters after the forward pass but before calling `backward()` makes the stored forward activations stale; autograd detects this and raises the error instead of silently computing wrong gradients.
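Here is a minimal sketch that reproduces the error; the model architecture, input shapes, and the `weights_original` dict are made up for illustration:

```python
import torch
import torch.nn as nn

# hypothetical model; any model with at least two parameterized layers
# reproduces the error, since the second linear layer's weight is saved
# to compute gradients flowing back to the earlier layer
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
weights_original = {name: param.clone() for name, param in model.state_dict().items()}

out = model(torch.randn(8, 4))
loss = out.sum()

# restoring the state_dict here modifies the parameters in-place
# while their old values are still needed for the backward pass
model.load_state_dict({name: weights_original[name] for name in weights_original})

# raises: RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation
loss.backward()
```

Calling `loss.backward()` before `load_state_dict` avoids the error, since the saved tensors are consumed before the parameters are overwritten.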
This post explains the issue in more detail for GAN training, which has the same root cause: stale forward activations.