For anyone still looking for a solution, you might find this thread useful. There are two ways to do it:
- Delete the weights of each layer manually and re-assign the tensor of interest (keeping its gradient history)
- Use the GitHub package SsnL/PyTorch-Reparam-Module (Reparameterize your PyTorch modules), but note that it might not work with batchnorm layers when `track_running_stats=True`.
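A minimal sketch of the first approach: delete the registered `nn.Parameter` and assign a plain tensor that carries autograd history, so gradients flow back to whatever produced the new weights. The names (`base`, `new_weight`) are just illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)

# A leaf tensor and a derived (non-leaf) weight with gradient history,
# standing in for e.g. the output of a hypernetwork.
base = torch.randn(2, 3, requires_grad=True)
new_weight = base * 2.0

# Delete the registered Parameter, then re-assign a plain tensor.
# nn.Module.__setattr__ stores it as an ordinary attribute, and
# F.linear in the forward pass accepts a plain tensor just fine.
del model.weight
model.weight = new_weight

out = model(torch.randn(4, 3)).sum()
out.backward()
# Gradients reach the leaf tensor behind the re-assigned weight.
```

Note that after this, `model.weight` no longer shows up in `model.parameters()`, so an optimizer built from the module will not see it; you would optimize `base` (or the hypernetwork) directly.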
Another useful package, supported directly by PyTorch, is functorch with `make_functional`.
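A short sketch of that route: `make_functional` splits a module into a stateless callable plus a tuple of parameters, which you can differentiate through directly. (In newer PyTorch releases `functorch.make_functional` is deprecated in favor of `torch.func.functional_call`, but the idea is the same.)

```python
import torch
import torch.nn as nn
from functorch import make_functional

model = nn.Linear(3, 2)

# Stateless version of the module + its parameters as a tuple of tensors.
fmodel, params = make_functional(model)

x = torch.randn(4, 3)
out = fmodel(params, x)          # forward pass with explicit parameters
loss = out.sum()

# Gradients w.r.t. the parameter tuple, without touching .grad buffers.
grads = torch.autograd.grad(loss, params)
```

Because the parameters are now plain inputs, you can substitute any tensors with history (e.g. inner-loop updated weights in meta-learning) and backpropagate through them.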