# Finding the total number of trainable parameters in a Graph

Suppose we have @Tudor_Berariu's code from "Manually feeding trainable parameters to the optimizer":

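```python
import torch
import torch.optim as optim
from torch.autograd import Variable

# w (3x5) and b (3) are the trainable parameters; x is a plain input
w = Variable(torch.randn(3, 5), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
x = Variable(torch.randn(5))

optimizer = optim.SGD([w, b], lr=0.01)

y = torch.mv(w, x) + b
y.backward(torch.randn(3))
optimizer.step()
```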

Is there an automated way to find the total number of trainable parameters used to construct `y`? In the example above, that total is the number of trainable parameters in `w` plus the number in `b`, i.e. 3*5 + 3 = 18.
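As a quick check of that count, here is a minimal sketch using the current tensor API, where `Variable` has been merged into `Tensor`. The shapes of `w` (3x5) and `b` (3) are assumed from the 3*5 + 3 arithmetic. Summing `numel()` over the tensors with `requires_grad=True` gives the figure, although the candidate tensors still have to be listed by hand:

```python
import torch

# Assumed shapes, matching the 3*5 + 3 count above
w = torch.randn(3, 5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.randn(5)  # input: does not require gradients

# Count elements only in tensors that autograd will compute gradients for
candidates = [w, b, x]
n_trainable = sum(t.numel() for t in candidates if t.requires_grad)
print(n_trainable)  # 18
```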

I tried this myself. My conclusion is that when you wrap parameters directly as `Variable`s with `requires_grad=True`, you cannot always distinguish model parameters from inputs, because inputs may need gradients too, for example when training a contractive autoencoder or searching for adversarial examples. Declaring parameters as `nn.Parameter` and checking variables with `isinstance(var, nn.Parameter)` therefore seems to be a better approach.
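To illustrate the point, the sketch below uses a hypothetical setup: the weights live in an `nn.Linear` module, so they are `nn.Parameter`s, and the input is given `requires_grad=True` as it would be in an adversarial-example search. Filtering on `requires_grad` alone over-counts, while filtering on `isinstance(t, nn.Parameter)` does not. When all parameters are registered on a module, summing over `module.parameters()` gives the same result without listing anything by hand.

```python
import torch
import torch.nn as nn

model = nn.Linear(5, 3)  # weight: 3x5, bias: 3 -> 18 parameters
# The input also requires gradients, e.g. for an adversarial-example search
x = torch.randn(5, requires_grad=True)

leaves = [x, model.weight, model.bias]

# requires_grad alone cannot tell inputs from parameters: 5 + 15 + 3
by_requires_grad = sum(t.numel() for t in leaves if t.requires_grad)
# isinstance(..., nn.Parameter) keeps only the model parameters: 15 + 3
by_parameter = sum(t.numel() for t in leaves if isinstance(t, nn.Parameter))
# Usual idiom when the parameters are registered on a module
by_module = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(by_requires_grad, by_parameter, by_module)  # 23 18 18
```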

Variables have some properties, such as `creator` and `previous_functions`, that you might exploit to walk back through the computational graph, but I'm not sure it's a robust approach. Here's an example that works in some cases but fails for others, convolutional layers for instance:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Parameter

x1 = Variable(torch.rand(10, 7))
x2 = Variable(torch.rand(10, 9))

l1 = nn.Linear(7, 5)
l2 = nn.Linear(9, 5)
l3 = nn.Linear(10, 5)

y = l3(torch.cat([l1(x1), l2(x2)], 1))

def get_all_params(var, all_params):
    # Parameters are graph leaves: record each one once, keyed by id
    if isinstance(var, Parameter):
        all_params[id(var)] = var.nelement()
    # Non-leaf Variable: recurse into the inputs of the function that created it
    elif hasattr(var, "creator") and var.creator is not None:
        if var.creator.previous_functions is not None:
            for j in var.creator.previous_functions:
                get_all_params(j, all_params)
    # Function node: recurse into its inputs
    elif hasattr(var, "previous_functions"):
        for j in var.previous_functions:
            get_all_params(j, all_params)

all_params = {}
get_all_params(y, all_params)
print(sum(all_params.values()))
```
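The `creator` and `previous_functions` attributes belong to the early (pre-0.2) autograd API; later releases expose the same information as `grad_fn` and `next_functions`, and leaf tensors that require gradients appear in the graph as `AccumulateGrad` nodes whose `variable` attribute is the leaf tensor itself. Below is a minimal sketch of the same idea against that newer API, assuming a recent PyTorch release; `count_graph_params` is a hypothetical helper name. It walks the graph iteratively with a visited set so shared nodes are not re-expanded, and keys parameters by `id` so a parameter used twice is counted once. Like the original, it only finds parameters that contributed to the output through differentiable operations, and it has not been checked against every layer type.

```python
import torch
import torch.nn as nn

def count_graph_params(output):
    """Sum numel() over the nn.Parameter leaves reachable from output.grad_fn.

    Hypothetical helper: walks autograd's graph via next_functions; leaf
    tensors show up as AccumulateGrad nodes exposing a `variable` attribute.
    """
    params = {}  # id(param) -> numel, so each parameter is counted once
    seen = set()  # graph nodes already expanded
    stack = [output.grad_fn] if output.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        leaf = getattr(node, "variable", None)
        if isinstance(leaf, nn.Parameter):
            params[id(leaf)] = leaf.numel()
        # next_functions is a tuple of (node, input_index); node may be None
        for next_node, _ in node.next_functions:
            if next_node is not None:
                stack.append(next_node)
    return sum(params.values())

x1 = torch.rand(10, 7)
x2 = torch.rand(10, 9)

l1 = nn.Linear(7, 5)
l2 = nn.Linear(9, 5)
l3 = nn.Linear(10, 5)

y = l3(torch.cat([l1(x1), l2(x2)], 1))

print(count_graph_params(y))  # 145 = (7*5 + 5) + (9*5 + 5) + (10*5 + 5)
```

Because the inputs `x1` and `x2` do not require gradients, they contribute `None` entries to `next_functions` and are skipped; even if an input did require gradients, it would not be counted, since it is not an `nn.Parameter`.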