Finding the total number of trainable parameters in a Graph

Suppose we have @Tudor_Berariu’s code from “Manually feeding trainable parameters to the optimizer”:

import torch
import torch.optim as optim
from torch.autograd import Variable

w = Variable(torch.randn(3, 5), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
x = Variable(torch.randn(5))

optimizer = optim.SGD([w,b], lr=0.01)

optimizer.zero_grad()
y = torch.mv(w, x) + b
y.backward(torch.randn(3))
optimizer.step()

Is there an automated way to find the total number of trainable parameters used to construct y? In the example above, that total is the number of trainable parameters in w plus the number of trainable parameters in b, i.e. 3*5 + 3 = 18.
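The arithmetic itself is just "multiply out each tensor's shape, then sum". Here is a tiny pure-Python sketch of that count. It assumes you already know which tensors are trainable, and `count_elements` is a hypothetical helper, not a PyTorch function. In PyTorch the per-tensor product is what `numel()` returns.

```python
import math

def count_elements(shapes):
    """Hypothetical helper: total number of scalar entries across tensors of the given shapes."""
    # math.prod(()) == 1, so a 0-d (scalar) shape counts as one element, like numel().
    return sum(math.prod(shape) for shape in shapes)

# Shapes of w (3x5) and b (3,) from the example above.
trainable_shapes = [(3, 5), (3,)]
print(count_elements(trainable_shapes))  # 18
```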

I tried this myself. My conclusion is that when you create parameters directly as Variables with requires_grad=True, you cannot always distinguish model parameters from inputs, because inputs sometimes need gradients too, for example when training a contractive autoencoder or searching for adversarial examples. Creating parameters with nn.Parameter and checking isinstance(var, nn.Parameter) therefore seems to be the better approach.
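A minimal sketch of the nn.Parameter approach, assuming a current PyTorch where plain tensors take `requires_grad` directly. The input `x` also requires gradients (as it would when searching for an adversarial example), but the `isinstance` filter leaves it out. The last two lines show the usual idiom when the parameters live inside a module.

```python
import torch
import torch.nn as nn

# Model parameters: nn.Parameter sets requires_grad=True by default.
w = nn.Parameter(torch.randn(3, 5))
b = nn.Parameter(torch.randn(3))

# An input that also requires gradients (e.g. for an adversarial search).
# It is a plain tensor, not an nn.Parameter, so it must not be counted.
x = torch.randn(5, requires_grad=True)

candidates = [w, b, x]

# Count only nn.Parameter instances that are actually trainable.
n_trainable = sum(t.numel() for t in candidates
                  if isinstance(t, nn.Parameter) and t.requires_grad)
print(n_trainable)  # 18

# Same count for an equivalent module: weight 3x5 plus bias 3.
layer = nn.Linear(5, 3)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 18
```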

Variables have some properties, such as creator or previous_functions, that you might exploit to walk back through the computational graph, but I’m not sure that is a robust approach. Here’s an example that works in some cases but fails for convolutional layers, for instance:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Parameter
from torch.autograd import Variable

x1 = Variable(torch.rand(10, 7))
x2 = Variable(torch.rand(10, 9))

l1 = nn.Linear(7, 5)
l2 = nn.Linear(9, 5)
l3 = nn.Linear(10, 5)

y = l3(torch.cat([l1(x1), l2(x2)], 1))

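# Recursively walk the graph backwards from `var`, recording the size of
# every nn.Parameter found (keyed by id so a shared parameter counts once).
# This relies on the legacy `creator` / `previous_functions` attributes of
# old-style Variables and autograd Functions.
def get_all_params(var, all_params):
    if isinstance(var, Parameter):
        all_params[id(var)] = var.nelement()
    elif hasattr(var, "creator") and var.creator is not None:
        # Non-leaf Variable: continue from the Function that created it.
        if var.creator.previous_functions is not None:
            for j in var.creator.previous_functions:
                get_all_params(j[0], all_params)
    elif hasattr(var, "previous_functions"):
        # Function object: visit each of its inputs.
        for j in var.previous_functions:
            get_all_params(j[0], all_params)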

all_params = {}
get_all_params(y, all_params)
print(sum(all_params.values()))
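In current PyTorch releases the equivalent attributes are called grad_fn and next_functions, and Variable has been merged into Tensor. The sketch below adapts the same idea to that API under one assumption: trainable leaf tensors show up in the graph as `AccumulateGrad` nodes that expose the leaf through a `.variable` attribute. That is how autograd behaves today, but it is an internal detail rather than a documented public API. `count_graph_params` is a hypothetical name. The walk is iterative and remembers visited nodes, so deep or branching graphs are handled and a shared parameter is still counted once.

```python
import torch
import torch.nn as nn

def count_graph_params(output):
    """Sum numel() of every distinct nn.Parameter reachable from output.grad_fn."""
    seen_nodes = set()
    params = {}  # id(parameter) -> numel, so shared parameters count once
    stack = [output.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen_nodes:
            continue  # None marks inputs that do not require grad
        seen_nodes.add(node)
        # Assumption: AccumulateGrad nodes hold the leaf tensor in `.variable`.
        leaf = getattr(node, "variable", None)
        if isinstance(leaf, nn.Parameter):
            params[id(leaf)] = leaf.numel()
        # next_functions is a tuple of (node, input_index) pairs.
        for next_node, _ in node.next_functions:
            stack.append(next_node)
    return sum(params.values())

# The same graph as above, written without Variable.
x1 = torch.rand(10, 7)
x2 = torch.rand(10, 9)
l1, l2, l3 = nn.Linear(7, 5), nn.Linear(9, 5), nn.Linear(10, 5)
y = l3(torch.cat([l1(x1), l2(x2)], 1))
print(count_graph_params(y))  # 145 = (7*5+5) + (9*5+5) + (10*5+5)
```

Because a Parameter with requires_grad=False gets no AccumulateGrad node, frozen parameters are left out of the count automatically.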