I have a toy LeNet-5 CNN architecture defined as follows:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet5(nn.Module):
    # def __init__(self, beta = 1.0):
    def __init__(self):
        super().__init__()
        # Trainable parameter for the swish activation function -
        # self.beta = nn.Parameter(torch.tensor(beta, requires_grad = True))
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn1 = nn.BatchNorm2d(num_features = 6)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn2 = nn.BatchNorm2d(num_features = 16)
        self.fc1 = nn.Linear(
            in_features = 256, out_features = 120,
            bias = False
        )
        self.bn3 = nn.BatchNorm1d(num_features = 120)
        self.fc2 = nn.Linear(
            in_features = 120, out_features = 84,
            bias = False
        )
        self.bn4 = nn.BatchNorm1d(num_features = 84)
        self.fc3 = nn.Linear(
            in_features = 84, out_features = 10,
            bias = True
        )

    def swish_fn(self, x):
        # Unused while self.beta is commented out in __init__; forward() uses nn.SiLU() instead.
        return x * torch.sigmoid(x * self.beta)

    def forward(self, x):
        x = nn.SiLU()(self.pool(self.bn1(self.conv1(x))))
        x = nn.SiLU()(self.pool(self.bn2(self.conv2(x))))
        x = x.view(-1, 256)
        x = nn.SiLU()(self.bn3(self.fc1(x)))
        x = nn.SiLU()(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x


model = LeNet5().to(device)


@torch.no_grad()
def init_weights(m):
    # print(m)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.fill_(1.0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.fill_(1.0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.fill_(1.0)
        if m.bias is not None:
            m.bias.fill_(1.0)
    return None


model.apply(init_weights)
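As a sanity check (assuming the usual 1×28×28 MNIST-style input, so 28 → 24 after conv1, 12 after pooling, 8 after conv2, 4 after pooling, giving 16·4·4 = 256 flattened features, which matches fc1's in_features), a dummy forward pass confirms the output shape:

# Sanity check, assuming 1x28x28 (MNIST-style) inputs.
model.eval()  # eval mode so the dummy pass does not update the BatchNorm running statistics
with torch.no_grad():
    dummy = torch.randn(2, 1, 28, 28, device = device)
    print(model(dummy).shape)  # torch.Size([2, 10])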
Looking at the keys in the state_dict() gives:
model.state_dict().keys()
"""
odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'conv2.weight', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'fc1.weight', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batches_tracked', 'fc2.weight', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var', 'bn4.num_batches_tracked', 'fc3.weight', 'fc3.bias'])
"""
However, when I access the named parameters, I get the following output:
for name_m, params in model.named_parameters():
    print(name_m, params.size())
"""
conv1.weight torch.Size([6, 1, 5, 5])
bn1.weight torch.Size([6])
bn1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
bn2.weight torch.Size([16])
bn2.bias torch.Size([16])
fc1.weight torch.Size([120, 256])
bn3.weight torch.Size([120])
bn3.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
bn4.weight torch.Size([84])
bn4.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
"""
The difference between the two sets of keys seems to be:
model_d = dict(model.named_parameters())
set(model.state_dict().keys()) - set(model_d.keys())
"""
{'bn1.num_batches_tracked',
'bn1.running_mean',
'bn1.running_var',
'bn2.num_batches_tracked',
'bn2.running_mean',
'bn2.running_var',
'bn3.num_batches_tracked',
'bn3.running_mean',
'bn3.running_var',
'bn4.num_batches_tracked',
'bn4.running_mean',
'bn4.running_var'}
"""
How can I access these in a loop, without resorting to hard-coded one-liners such as:
model.bn1.running_mean
model.bn1.running_var
I want to access the batch-norm layers' members, viz. running_mean, running_var, and num_batches_tracked.
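For concreteness, this is roughly the kind of loop I have in mind, sketched under the assumption that named_buffers() and named_modules() expose these tensors:

# Option 1: iterate over every registered buffer of the model.
for name, buf in model.named_buffers():
    print(name, buf.shape)  # e.g. bn1.running_mean, bn1.running_var, bn1.num_batches_tracked, ...

# Option 2: pick out the BatchNorm modules and read their attributes directly.
for name, module in model.named_modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        print(name, module.running_mean.shape, module.running_var.shape,
              module.num_batches_tracked.item())

Either variant would avoid hard-coding bn1, bn2, and so on by name.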