Convert tensor to Parameter (while keeping the graph)

I want to create a neural network that generates the parameters of another NN: network ‘A’ generates the parameters of network ‘B’.

But there is a problem: the output of ‘A’ is a tensor (with requires_grad=True), while the parameters of ‘B’ are of type ‘Parameter’. Since I want to backpropagate the gradient all the way to the parameters of ‘A’, I am looking for a tensor -> Parameter conversion that keeps the computational graph.

A naive conversion ‘w = Parameter(y)’, where w is the generated parameter of ‘B’ and y is the output of ‘A’, does not work: the computational graph is cut between w and y (the gradients of the parameters of ‘A’ are None, while w, which should behave like a non-leaf tensor, ends up with its own gradient).
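A minimal sketch reproducing the behaviour described above (the network names and shapes are made up for illustration, and the exact behaviour of nn.Parameter on a non-leaf tensor may differ slightly across PyTorch versions):

```python
import torch
import torch.nn as nn

# Hypothetical setup: 'A' is a small linear hypernetwork producing a 2x3 weight for 'B'
A = nn.Linear(4, 6)
z = torch.randn(1, 4)

y = A(z).view(2, 3)      # output of 'A': non-leaf tensor, connected to A via grad_fn
w = nn.Parameter(y)      # naive conversion

print(y.grad_fn)         # e.g. <ViewBackward0 ...> -- y is part of A's graph
print(w.grad_fn)         # None -- w is a fresh leaf, the graph is cut here

w.sum().backward()
print(w.grad)            # populated (w is now a leaf)
print(A.weight.grad)     # None -- the gradient never reaches A's parameters
```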


It depends on what ‘B’ does. Have you looked into torch.nn.functional APIs?

‘B’ is an NN with fully connected layers and convolutional layers, whose parameters are of type ‘Parameter’.
So I could replace the usual layers with calls to torch.nn.functional functions… but then I would need to rewrite every layer I want to use. Is this the only solution?

If you want network ‘A’ to predict the parameters of network ‘B’, this is probably easier than meddling with the default nn layers.
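For what it's worth, here is a minimal sketch of that approach, with hypothetical names (HyperA, forward_B): ‘B’ is expressed through torch.nn.functional calls, so the tensors generated by ‘A’ are used directly and stay in the graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperA(nn.Module):
    """Hypothetical generator 'A': produces weight and bias for one linear layer of 'B'."""
    def __init__(self, z_dim, out_features, in_features):
        super().__init__()
        self.out_features = out_features
        self.in_features = in_features
        self.gen = nn.Linear(z_dim, out_features * in_features + out_features)

    def forward(self, z):
        p = self.gen(z)
        n = self.out_features * self.in_features
        w = p[:n].view(self.out_features, self.in_features)
        b = p[n:]
        return w, b

def forward_B(x, w, b):
    # 'B' written with torch.nn.functional: the generated tensors are used as-is
    return F.linear(x, w, b)

A = HyperA(z_dim=8, out_features=5, in_features=10)
w, b = A(torch.randn(8))
out = forward_B(torch.randn(3, 10), w, b)
out.sum().backward()
print(A.gen.weight.grad is not None)   # True: gradients reach A's parameters
```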

Just responding to the actual title of the question: what is wrong with using nn.Parameter(old_tensor), where old_tensor is of type torch.Tensor? Does that not work for you? Why not?

As I mentioned, new_tensor = nn.Parameter(old_tensor) breaks the computational graph of autograd.

In other words, new_tensor no longer carries any information about how old_tensor was computed, which is exactly the information I need to keep.


Hi @Zen, did you ever find a solution to this problem? I am now facing exactly the same issue.

For anyone still looking for a solution, you might find this thread useful. There are two ways to do it:

  1. Delete the weights of each layer manually and re-assign the tensor of interest (with its history); see the sketch after this list.
  2. Use a package from GitHub - SsnL/PyTorch-Reparam-Module: Reparameterize your PyTorch modules, but note that it might not work with batch norm layers when track_running_stats=True.
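A minimal sketch of option 1, assuming ‘B’ is just a single nn.Linear and ‘A’ is a hypothetical linear generator of its parameters: the registered Parameters are deleted and replaced by plain attributes holding the generated (non-leaf) tensors.

```python
import torch
import torch.nn as nn

B = nn.Linear(10, 5)
A = nn.Linear(8, 5 * 10 + 5)      # hypothetical generator of B's parameters

p = A(torch.randn(8))
w = p[:50].view(5, 10)            # generated weight, still attached to A's graph
b = p[50:]                        # generated bias

# Remove the registered Parameters, then assign plain tensors that keep their history.
del B.weight
del B.bias
B.weight = w
B.bias = b

out = B(torch.randn(3, 10))
out.sum().backward()
print(A.weight.grad is not None)  # True: the computational graph was kept
```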

Another useful package that is supported directly by PyTorch is functorch with make_functional.
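A minimal sketch with functorch.make_functional (the generator ‘A’ and all shapes are made up for illustration): make_functional turns ‘B’ into a stateless function that takes its parameters as arguments, so the tensors produced by ‘A’ can be passed in directly.

```python
import torch
import torch.nn as nn
from functorch import make_functional

B = nn.Linear(10, 5)
func_B, params = make_functional(B)    # stateless B + its original parameters

n_params = sum(p.numel() for p in params)
A = nn.Linear(8, n_params)             # hypothetical generator of B's parameters

flat = A(torch.randn(8))

# Split A's flat output into tensors shaped like B's parameters
new_params, i = [], 0
for p in params:
    new_params.append(flat[i:i + p.numel()].view_as(p))
    i += p.numel()

out = func_B(tuple(new_params), torch.randn(3, 10))
out.sum().backward()
print(A.weight.grad is not None)       # True: gradients flow back to A
```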