Hi,
I have a very specific use case in PyTorch and I am not sure what the suggested way to implement it is. I will first describe what I want to do, then how I attempted it.
Goal: Say I have a model called net with trainable parameters, of which I am only interested in the convolutional layers:
```python
for module in net.modules():
    if is_conv_layer(module):
        module.weight  # This is the weight I am interested in.
```
Now, what I want to do is replace each module.weight with a new Parameter, call it weight_new, where weight_new = weight * weight2. The important point is that I want to update only the weight2 values and use the initial values of module.weight as fixed constants.
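In isolation, the computation I have in mind looks something like this (shapes are made up for illustration; weight2 is the only tensor that should receive gradients):

```python
import torch

# Made-up shape for illustration: a 4x3x3x3 conv kernel.
weight = torch.randn(4, 3, 3, 3)                       # initial values, fixed
weight2 = torch.nn.Parameter(torch.ones_like(weight))  # the trainable part
weight_new = weight * weight2                          # what the conv should use

weight_new.sum().backward()
print(weight.requires_grad, weight2.grad is not None)  # False True
```

Here the gradient of the sum with respect to weight2 is just weight, so gradients flow only into weight2 while weight stays fixed.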
Attempt:
```python
for module in net.modules():
    if is_conv_layer(module):
        data = module.weight.data
        module.weight = weight2
        module.weight.data *= data
```
So, basically, I replace the trainable parameters with the new Parameters weight2, and only afterwards update the .data term so that it holds the product of weight and weight2. I do it in this order because I cannot modify the Parameter before assigning it to module.weight.
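A self-contained version of that loop on a single conv layer (the layer sizes are made up), in case it helps reproduce what I am seeing:

```python
import torch
import torch.nn as nn

# Minimal stand-in for my network: one conv layer.
net = nn.Sequential(nn.Conv2d(3, 4, 3))
conv = net[0]

data = conv.weight.data.clone()                       # keep the initial weight
weight2 = nn.Parameter(torch.ones_like(conv.weight))  # the new Parameter
conv.weight = weight2                                 # replace the module's weight
conv.weight.data *= data                              # .data now holds weight * weight2

loss = net(torch.randn(1, 3, 8, 8)).sum()
loss.backward()
print(weight2.grad is not None)
```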
When I minimize the loss, I noticed that weight2 is not actually updating. Am I doing something wrong, and what would be the suggested way of doing this?
If I am not mistaken, I have read that overwriting weight.data is not a recommended practice in PyTorch, and it may be the reason behind this weird behavior. Is that really the case?
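If I understand what I read correctly, the concern is that an in-place change through .data is invisible to autograd, so the gradient does not account for it. A tiny made-up example of what I mean:

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
w.data *= 2.0  # not recorded by autograd

loss = w.sum()
loss.backward()
# If the *= 2 had been part of the graph (w2 = w * 2; w2.sum().backward()),
# the gradient would be all 2s; done through .data, it stays all ones.
print(w.grad)  # tensor([1., 1., 1.])
```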