How to use shared weights in different layers of a model

I am trying to share weights between different layers in one model. Please take a look at this example code:

This code tries to share the weights of fc1 with some columns of fc2:

import torch
import torch.nn as nn
import torch.optim as optim

class testModule(nn.Module):

    def __init__(self):
        super(testModule, self).__init__()
        self.fc1 = nn.Linear(5, 10, bias=True)
        self.fc2 = nn.Linear(10, 10, bias=False)

        self.shared_weights = nn.Parameter(torch.randn(10, 5), requires_grad=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def share_weight(self):
        # Attempt to tie fc1.weight and columns `index` of fc2.weight to shared_weights
        index = [1, 3, 5, 7, 9]
        self.fc1.weight.data = self.shared_weights
        self.fc2.weight.data[:, index] = self.shared_weights

After calling the .share_weight() method and training, the weights in fc1.weight and fc2.weight[:, index] become different.

Why does this happen, and what is the behavior behind assigning weight.data to another tensor? How can I keep fc1.weight and fc2.weight[:, index] shared during training?
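
For reference, this is roughly how I observe the divergence (the data, loss function, and optimizer here are arbitrary placeholders, just to reproduce the behavior):

model = testModule()
model.share_weight()
opt = optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    x = torch.randn(8, 5)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

index = [1, 3, 5, 7, 9]
# the two sets of weights no longer match after a few updates
print(torch.equal(model.fc1.weight, model.fc2.weight[:, index]))  # False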

Hi,

.data is in the process of being removed and should not be used. As you have experienced, it only does very confusing things :smiley: In your share_weight(), the slice assignment into fc2.weight.data only copies the current values of shared_weights; it does not link the storages, so fc1.weight and fc2.weight[:, index] keep receiving their own gradients and optimizer updates and drift apart.
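
To see part of the problem in isolation (a standalone sketch, not your model): a slice assignment through .data only copies the values at that moment, it does not keep the two tensors linked afterwards:

import torch

a = torch.zeros(3, 3)
b = torch.ones(3)

a.data[:, 0] = b   # copies b's current values into a's storage
b += 1             # later changes to b do not propagate to a
print(a[:, 0])     # tensor([1., 1., 1.]), not 2s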

You will need to keep only the nn.Parameters as the true parameters and recompute everything else at each forward:

import torch
import torch.nn as nn
import torch.optim as optim

class testModule(nn.Module):

    def __init__(self):
        super(testModule, self).__init__()
        self.fc1 = nn.Linear(5, 10, bias=True)
        self.fc2 = nn.Linear(10, 10, bias=False)

        # Remove the weights as we override them in the forward
        # so that they don't show up when calling .parameters()
        del self.fc1.weight
        del self.fc2.weight

        self.fc2_base_weights = nn.Parameter(torch.randn(10, 10))
        self.shared_weights = nn.Parameter(torch.randn(10, 5))

    def forward(self, x):
        # Rebuild the layer weights from the true Parameters on every forward
        index = [1, 3, 5, 7, 9]
        # fc1 uses the shared Parameter directly
        self.fc1.weight = self.shared_weights
        # fc2's weight is rebuilt from the base weights with the shared columns
        # written in: gradients for those columns flow to shared_weights, the
        # rest flow to fc2_base_weights
        self.fc2.weight = self.fc2_base_weights.clone()
        self.fc2.weight[:, index] = self.shared_weights

        x = self.fc1(x)
        x = self.fc2(x)
        return x
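
To check that the sharing holds, you can run a short training loop and compare the recomputed weights (the data, loss, and optimizer below are arbitrary, just for the sanity check):

model = testModule()
opt = optim.SGD(model.parameters(), lr=0.1)

for _ in range(5):
    x = torch.randn(8, 5)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# run one forward so fc1.weight / fc2.weight are rebuilt from the updated Parameters
with torch.no_grad():
    model(torch.randn(1, 5))

index = [1, 3, 5, 7, 9]
print(torch.equal(model.fc1.weight, model.fc2.weight[:, index]))  # True

Since the layer weights are reassembled from the Parameters on every forward, the optimizer only ever updates shared_weights, fc2_base_weights and fc1.bias, so the tied columns cannot go out of sync.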

Thank you! This way works!