How to use shared weights in different layers of a model

Hi,

.data is in the process of being removed and should not be used. As you have experienced, it only leads to very confusing behavior :smiley:

You will need to keep only the nn.Parameters as the true parameters and recompute everything else at each forward:

import torch
import torch.nn as nn
import torch.optim as optim

class testModule(nn.Module):

    def __init__(self):
        super(testModule, self).__init__()
        self.fc1 = nn.Linear(5, 10, bias=True)
        self.fc2 = nn.Linear(10, 10, bias=False)

        # Remove the weights as we override them in the forward
        # so that they don't show up when calling .parameters()
        del self.fc1.weight
        del self.fc2.weight

        self.fc2_base_weights = nn.Parameter(torch.randn(10, 10))
        self.shared_weights = nn.Parameter(torch.randn(10, 5))

    def forward(self, x):
        # Recompute the layer weights at every forward
        index = [1, 3, 5, 7, 9]
        # fc1 uses the shared parameter directly
        self.fc1.weight = self.shared_weights
        # Clone so that the in-place indexed assignment below does not
        # modify fc2_base_weights itself (gradients still flow back to it)
        self.fc2.weight = self.fc2_base_weights.clone()
        # Overwrite the selected columns of fc2 with the shared weights
        self.fc2.weight[:, index] = self.shared_weights

        x = self.fc1(x)
        x = self.fc2(x)
        return x

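As a quick sanity check (continuing from the snippet above, with an arbitrary batch size of 2), you can verify that only the true nn.Parameters are registered and that gradients reach the shared weights through both layers:

model = testModule()

# Only the real nn.Parameters show up (fc2_base_weights,
# shared_weights and fc1.bias); the deleted .weight attributes don't
for name, p in model.named_parameters():
    print(name, tuple(p.shape))

optimizer = optim.SGD(model.parameters(), lr=0.1)

out = model(torch.randn(2, 5))
out.sum().backward()
# shared_weights gets gradient contributions from both fc1 and the
# selected columns of fc2
print(model.shared_weights.grad.abs().sum())
optimizer.step()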