Number of parameters when using a shared module?

Hello there! I’m experimenting with a network where, in my forward pass, I reuse several layers that are initialized in my network’s __init__, so that I can tie their weights together. However, when I use a parameter counter like torchsummary, it reports the same number of parameters as for a net with a separate module for each layer.

My question is: is reusing a layer module enough to tie the weights? This sounds like a stupid question to me, since backprop shouldn’t have additional parameters if no additional module exists, but I was wondering whether reusing the layer creates some kind of shadow layer that caches separate parameters for each use.
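One way to check whether reusing a module really ties the weights is to inspect the registered parameters and the resulting gradient. Below is a minimal sketch, assuming a hypothetical `TiedTwice` module that applies one bias-free `nn.Linear(1, 1)` twice. Only one weight is registered, and autograd sums the gradient contributions from both calls into that single weight’s `.grad`, which is what you would expect if there is no per-call shadow copy.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical example: one Linear layer applied twice in forward().
class TiedTwice(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        return self.fc(self.fc(x))  # y = w * (w * x) = w^2 * x

model = TiedTwice()
params = list(model.parameters())
print(len(params))  # 1: only one weight tensor is registered

x = torch.tensor([[3.0]])
y = model(x)
y.sum().backward()

w = model.fc.weight.item()
# Both uses contribute to the same .grad: d(w^2 * x)/dw = 2 * w * x
expected = 2 * w * x.item()
print(model.fc.weight.grad.item(), expected)
```

If each call had its own hidden copy of the weight, the gradient on `model.fc.weight` would only reflect one use rather than the sum of both.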

Your approach should be correct, and I see the same behavior with this dummy model:

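```python
import torch
import torch.nn as nn
from torchsummary import summary

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # a single 1x1 layer without bias: exactly one weight parameter
        self.fc1 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        # the same fc1 module (and thus the same weight) is applied 10 times
        for _ in range(10):
            x = self.fc1(x)
        return x

model = MyModel()
x = torch.randn(1, 1)

# device="cpu" because the model was not moved to the GPU
summary(model, [(1, 1)], device="cpu")
```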
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 1, 1]               1
            Linear-2                 [-1, 1, 1]               1
            Linear-3                 [-1, 1, 1]               1
            Linear-4                 [-1, 1, 1]               1
            Linear-5                 [-1, 1, 1]               1
            Linear-6                 [-1, 1, 1]               1
            Linear-7                 [-1, 1, 1]               1
            Linear-8                 [-1, 1, 1]               1
            Linear-9                 [-1, 1, 1]               1
           Linear-10                 [-1, 1, 1]               1
================================================================
Total params: 10
Trainable params: 10
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
```

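so I assume torchsummary doesn’t account for shared modules: it seems to add a row and count the parameters every time a module is called in the forward pass, so fc1 is counted 10 times even though the model has only one weight.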
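As a cross-check that doesn’t depend on how torchsummary records calls, you can count the parameters directly from `model.parameters()`, which yields each registered parameter once no matter how often its module is called. A minimal sketch, reusing the dummy model from above:

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        # fc1 is reused 10 times, but it is still a single module
        for _ in range(10):
            x = self.fc1(x)
        return x

model = MyModel()

# parameters() iterates over registered parameters, not over forward calls,
# so the shared weight is counted only once.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 1
```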
