Convert tensor to Parameter (while keeping the graph)

For anyone still looking for a solution, you might find this thread useful. There are two ways to do it:

  1. Manually delete the weight of each layer and re-assign it with the tensor of interest (which keeps its history).
  2. Use the SsnL/PyTorch-Reparam-Module package on GitHub ("Reparameterize your PyTorch modules"), but note that it might not work with batch-norm layers when track_running_stats=True.
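A minimal sketch of the first (manual) approach, assuming you already have the replacement tensor and know the dotted name of the parameter it should replace. The helper `replace_with_tensor` and the `base` tensor are hypothetical, for illustration only. The key point is that `nn.Module` refuses to assign a plain tensor to a name still registered as a Parameter, so the Parameter has to be deleted first:

```python
import torch
import torch.nn as nn


def replace_with_tensor(model: nn.Module, param_name: str, tensor: torch.Tensor) -> None:
    """Hypothetical helper: swap a registered Parameter for a plain tensor with history."""
    # Split "0.weight" into the owning submodule path ("0") and the attribute ("weight").
    module_path, _, attr = param_name.rpartition(".")
    module = model.get_submodule(module_path) if module_path else model
    # Delete first: assigning a plain tensor to a name that is still a
    # registered Parameter raises a TypeError.
    delattr(module, attr)
    # The plain tensor is stored as an ordinary attribute, so forward() uses it
    # and gradients flow back through whatever computation produced it.
    setattr(module, attr, tensor)


model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))

# Hypothetical "tensor of interest": the result of a computation, so it has a grad_fn.
base = torch.randn(4, 3, requires_grad=True)
replace_with_tensor(model, "0.weight", base * 2.0)

loss = model(torch.randn(8, 3)).sum()
loss.backward()
print(base.grad is not None)  # True: the gradient reached the tensor's source
```

Because `0.weight` is no longer a registered Parameter, it disappears from `model.parameters()` and `model.state_dict()`, so an optimizer built from `model.parameters()` will not touch it. You would repeat the call for every layer whose weight should come from the graph.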

Another useful option, supported directly by PyTorch, is functorch's make_functional.
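A rough sketch of the functional route. Note that this uses `torch.func.functional_call`, the PyTorch 2.x replacement for functorch's (now deprecated) `make_functional`, rather than `make_functional` itself. It runs the module with a dictionary of tensors substituted for its parameters, without modifying the module. The `scale` tensor is a hypothetical stand-in for whatever computation produces your weights:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))

# Hypothetical weight source: tensors computed from a learnable value,
# so every entry carries autograd history back to `scale`.
scale = torch.tensor(1.5, requires_grad=True)
new_params = {name: p.detach() * scale for name, p in model.named_parameters()}

x = torch.randn(8, 3)
# Call the model with new_params in place of its registered parameters;
# the module's own Parameters are left untouched.
out = functional_call(model, new_params, (x,))
out.sum().backward()
print(scale.grad is not None)  # True: gradients flow to the weight source
```

Unlike the manual approach, nothing is deleted from the module, so the same model object can still be trained normally or called with different parameter sets.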