Get 1D tensor of all trainable parameters

Hello!

In Torch, I could use the following command:

cnn_params, cnn_grad_params = cnn:getParameters()

to get a 1D tensor of all the trainable parameters of a given model (and corresponding gradients).
Is it possible to do something similar in PyTorch, so that cnn_params shares memory with the corresponding model? I should mention that I only care about the trainable parameters (i.e., weights and biases) and not their gradients.

Thanks a lot for any help!

You can get the parameters by calling model.parameters(). This will return a generator, so if you need a list just call list(model.parameters()).
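
As a minimal sketch of this, assuming a small hypothetical nn.Sequential model that is not part of the original question, the generator can be materialized and filtered down to the trainable parameters like so:

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# parameters() yields each weight and bias tensor lazily; list() materializes them.
params = list(model.parameters())

# Keep only the parameters that will receive gradients.
trainable = [p for p in params if p.requires_grad]

# Total number of trainable scalars: (4*3 + 3) + (3*2 + 2) = 23
num_trainable = sum(p.numel() for p in trainable)
print(len(trainable), num_trainable)
```

Each entry in the list is the model's own Parameter object, so an in-place change made to one of them (under torch.no_grad()) is visible in the model.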

Thanks for your answer, ptrblck. My idea, if possible, is to manipulate a single 1D tensor holding all the parameters, so that I can avoid looping through the list and repeating the same operation for each individual module (e.g., replacing parameters above a certain value), in order to speed up the algorithm.

Thanks again!

torch.cat([param.view(-1) for param in model.parameters()])

Should work for you.

Thank you. I tried the solution that you suggested but it seems that this object does not share the same memory as the initial model.

Here is an example:

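import torch
import torch.nn as nn

# Three small conv layers held in a ModuleList
net = nn.ModuleList()
net.append(nn.Conv2d(3, 8, 3))
net.append(nn.Conv2d(8, 8, 3))
net.append(nn.Conv2d(8, 1, 3))

# torch.cat returns a new tensor, so filling it leaves the model untouched
unrolled = torch.cat([param.view(-1) for param in net.parameters()])
unrolled.fill_(3.14)
print(net[0].weight[0,0,0,0])
print(unrolled[0])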

Outputs:
tensor(0.0684, grad_fn=SelectBackward)
tensor(3.1400, grad_fn=SelectBackward)
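
The mismatch above can be checked directly by comparing storage addresses. This is only a sketch: the variable names first, flat_view, shares_view and shares_cat are my own, and it assumes the same three-layer net as in the example. A view of a single parameter keeps the parameter's storage, while the result of torch.cat is written into newly allocated memory.

```python
import torch
import torch.nn as nn

# Same layer shapes as in the example above.
net = nn.ModuleList([nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3), nn.Conv2d(8, 1, 3)])

first = net[0].weight

# view(-1) only reinterprets the shape; it points at the same memory.
flat_view = first.view(-1)

# torch.cat copies every chunk into a freshly allocated tensor.
unrolled = torch.cat([p.view(-1) for p in net.parameters()])

shares_view = flat_view.data_ptr() == first.data_ptr()
shares_cat = unrolled.data_ptr() == first.data_ptr()
print(shares_view, shares_cat)
```

So the concatenated tensor is a copy, and writing to it cannot change the model's weights.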

In case anyone else is interested, the closest solution I was able to find involves using
torch.nn.utils.parameters_to_vector() to get the parameters vector and then calling torch.nn.utils.vector_to_parameters() when I'm done modifying it.
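
A rough sketch of that round trip, assuming the same three-layer net as earlier in the thread and a made-up threshold of 0.05 for the "replace parameters above a certain value" step:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Same layer shapes as in the earlier example.
net = nn.ModuleList([nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3), nn.Conv2d(8, 1, 3)])

# no_grad keeps the edit out of the autograd graph.
with torch.no_grad():
    # Flatten all parameters into one 1D tensor (this is a copy).
    vec = parameters_to_vector(net.parameters())

    # One vectorized operation instead of a per-module loop.
    # The 0.05 threshold is a hypothetical value for illustration.
    vec.clamp_(max=0.05)

    # Write the modified values back into the model's parameters.
    vector_to_parameters(vec, net.parameters())

max_value = max(p.max().item() for p in net.parameters())
print(vec.numel(), max_value)
```

The vector is not kept in sync with the model, so the write-back step is needed every time the vector is modified.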

Maybe this link will give a proper explanation: http://cat2.mit.edu/dh_upload/backup/transfer/source/pytorch/test/test_nn.py

from torch.nn.utils import parameters_to_vector, vector_to_parameters

Also, this link may help: