PyTorch summary - parameter count

I am experimenting with various deep networks and I always want to know how many parameters are involved. I was using pytorch-summary (torchsummary), but I noticed that if I call the same module several times in the forward pass, its parameters are counted once per call.

Here is an example:

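import torch.nn as nn
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, 3)

    def forward(self, x):
        # The same Linear layer (and therefore the same parameters) is applied four times
        x = self.lin(x)
        x = self.lin(x)
        x = self.lin(x)
        x = self.lin(x)
        return x

net = Net()
summary(net, (3,), device="cpu")  # input shape (3,) is assumed; any input whose last dim is 3 works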

This reports 48 total params, i.e. 4 × 12. However, the network only has 12 trainable parameters (a 3×3 weight matrix plus 3 biases); torchsummary counts them once for each of the four calls to self.lin.
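To see where the 4 × 12 comes from, here is a minimal sketch that mimics, in simplified form, how a hook-based summary accumulates its total: a forward hook fires on every call of a module, so a module called four times contributes its parameters four times, whereas net.parameters() yields each Parameter object only once. The helper name per_call_param_count is hypothetical, and this is not torchsummary's actual source code.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Condensed copy of the Net from the question."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 3)

    def forward(self, x):
        for _ in range(4):  # the same Linear layer, applied four times
            x = self.lin(x)
        return x


def per_call_param_count(model, example_input):
    """Hypothetical helper: add up a module's own parameters every time
    its forward hook fires, i.e. once per call."""
    total = 0
    hooks = []

    def hook(module, inputs, output):
        nonlocal total
        # recurse=False: only the module's own weight/bias, not its children's
        total += sum(p.numel() for p in module.parameters(recurse=False))

    for m in model.modules():
        if m is not model:  # skip the top-level model itself
            hooks.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(example_input)
    for h in hooks:
        h.remove()
    return total


net = Net()
per_call = per_call_param_count(net, torch.rand(2, 3))
unique = sum(p.numel() for p in net.parameters())
print(per_call, unique)  # 48 12
```

The hook on self.lin fires four times and adds 12 each time, while parameters() deduplicates by object identity, which is why the plain parameter sum gives 12.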

So my question is: is there a way to make torchsummary print the number of unique trainable parameters of the model, i.e. counting each shared parameter only once?

Otherwise, I know I can use a snippet like this

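# net.parameters() yields each Parameter only once, even if its module is called several times
model_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum(p.numel() for p in model_parameters)
print(f"The network has {params} trainable parameters")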

to get the desired result, but I like the way torchsummary presents its output.
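For a summary-style table that still counts shared parameters once, one option is a small hand-written helper. The sketch below assumes a model whose hooked modules each return a single tensor; summary_unique is a hypothetical name, not part of torchsummary. It prints one row per module call (as torchsummary does), flags calls whose parameters were already shown, and takes the total from model.parameters(), which counts each Parameter once.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Condensed copy of the Net from the question."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 3)

    def forward(self, x):
        for _ in range(4):  # the same Linear layer, applied four times
            x = self.lin(x)
        return x


def summary_unique(model, example_input):
    """Hypothetical helper, not part of torchsummary: one row per module call,
    but the total counts each trainable Parameter only once."""
    rows = []
    seen = set()  # ids of Parameters already shown in an earlier row
    hooks = []

    def hook(module, inputs, output):
        own = list(module.parameters(recurse=False))
        # A call is "shared" if all of its parameters appeared in an earlier row
        shared = bool(own) and all(id(p) in seen for p in own)
        seen.update(id(p) for p in own)
        rows.append((type(module).__name__, tuple(output.shape),
                     sum(p.numel() for p in own), shared))

    for m in model.modules():
        if m is not model:  # skip the top-level model itself
            hooks.append(m.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for h in hooks:
            h.remove()

    # model.parameters() yields each Parameter once, so reused layers are not re-counted
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"{'Layer':<12}{'Output shape':<16}{'Params':>8}  Shared")
    for name, shape, n_params, shared in rows:
        print(f"{name:<12}{str(shape):<16}{n_params:>8}  {'yes' if shared else ''}")
    print(f"Unique trainable params: {total}")
    return rows, total


rows, total = summary_unique(Net(), torch.rand(2, 3))
```

For the Net above this prints four Linear rows, marks the last three as shared, and reports 12 unique trainable parameters. Container modules such as nn.Sequential also get a row (with 0 own params), and modules returning tuples would need extra handling in the hook.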