I am experimenting with various deep networks and I always want to know how many parameters are involved. I have been using torchsummary, but I noticed that if the same module is used several times in the forward pass, its parameters are counted multiple times.
An example is this:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, 3)

    def forward(self, x):
        x = self.lin(x)
        x = self.lin(x)
        x = self.lin(x)
        x = self.lin(x)
        return x
net = Net()
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(net.to(device), (1, 3))
where you get 48 Total params, which is 12 * 4. Here 12 (a 3x3 weight matrix plus 3 biases) is the actual number of trainable parameters of the network.
Thus, my question is: is there a way to make torchsummary print the number of unique trainable parameters of the model?
Otherwise, I know I can use a snippet like this
import numpy as np

model_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum(np.prod(p.size()) for p in model_parameters)
print(f"The network has {params} trainable parameters")
to get the desired result, but I like how torchsummary presents its output.
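For reference, here is a small self-contained sketch that counts each parameter tensor only once by deduplicating on object identity. The helper name count_unique_trainable_params is my own, not part of any library; note that model.parameters() already deduplicates shared tensors in recent PyTorch versions, so the id() check mainly makes the intent explicit:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, 3)

    def forward(self, x):
        # The same layer is applied four times, but its
        # parameters exist only once in memory.
        for _ in range(4):
            x = self.lin(x)
        return x

def count_unique_trainable_params(model):
    # Deduplicate by id() so a tensor shared across several
    # submodules (weight tying) is counted exactly once.
    seen = set()
    total = 0
    for p in model.parameters():
        if p.requires_grad and id(p) not in seen:
            seen.add(id(p))
            total += p.numel()
    return total

print(count_unique_trainable_params(Net()))  # 12 (3*3 weight + 3 bias)
```

This gives 12 for the toy network above regardless of how many times self.lin is called in forward, since the count depends only on the registered parameters, not on the forward graph.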