During training I am making some of my model's layers non-trainable via:
for param in model.parameters():
    param.requires_grad = False
However, after checking the parameters I see that many of them still change during training, such as:
extras.0.conv.7.running_var
extras.1.conv.1.running_mean
extras.1.conv.1.running_var
extras.1.conv.4.running_mean
extras.1.conv.4.running_var
extras.1.conv.7.running_mean
extras.1.conv.7.running_var
extras.2.conv.1.running_mean
extras.2.conv.1.running_var
extras.2.conv.4.running_mean
extras.2.conv.4.running_var
extras.2.conv.7.running_mean
extras.2.conv.7.running_var
extras.3.conv.1.running_mean
extras.3.conv.1.running_var
extras.3.conv.4.running_mean
extras.3.conv.4.running_var
extras.3.conv.7.running_mean
extras.3.conv.7.running_var
After searching a lot, I noticed these are the batch norm running statistics.
How can I freeze them, or in other words set requires_grad = False on them?
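If I understand correctly (this is my assumption, not something I've confirmed), running_mean and running_var are registered as buffers rather than parameters, so they never appear in model.parameters() and requires_grad does not apply to them at all. A quick check of that assumption:

import torch.nn as nn

bn = nn.BatchNorm2d(6)
# Parameters respond to requires_grad; buffers do not
print([n for n, _ in bn.named_parameters()])  # ['weight', 'bias']
print([n for n, _ in bn.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']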
The following toy example shows that requires_grad = False won't work correctly:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.convs = nn.ModuleList([nn.Conv2d(3, 6, 3),
                                    nn.BatchNorm2d(6),
                                    nn.Conv2d(6, 10, 3),
                                    nn.Conv2d(10, 10, 3)])
        self.fcs = nn.Sequential(nn.Linear(320, 10),
                                 nn.ReLU(),
                                 nn.Linear(10, 5),
                                 nn.ReLU(),
                                 nn.Linear(5, 1))

    def forward(self, x):
        x = self.convs[0](x)
        x = self.convs[1](x)
        x = self.convs[2](x)
        x = self.convs[3](x)
        x = x.view(-1,)
        x = self.fcs(x)
        return x
model = Net()
loss = nn.L1Loss()
target = Variable(torch.ones(1))

# Freeze everything except two hand-picked parameters
for name, param in model.named_parameters():
    if name == 'convs.0.bias' or name == 'fcs.2.weight':
        param.requires_grad = True
    else:
        param.requires_grad = False

# Snapshot every entry of the state dict before training
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()
print(old_state_dict.keys())

optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
for epoch in range(5):
    X = Variable(torch.rand(2, 3, 10, 10))
    out = model(X)
    output = loss(out, target)
    optimizer.zero_grad()  # clear stale gradients before backward
    output.backward()
    optimizer.step()
# Snapshot again and compare against the pre-training values
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

count = 0
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))
        count += 1
print(count)
Output:
dict_keys(['convs.0.weight', 'convs.0.bias', 'convs.1.weight', 'convs.1.bias', 'convs.1.running_mean', 'convs.1.running_var', 'convs.2.weight', 'convs.2.bias', 'convs.3.weight', 'convs.3.bias', 'fcs.0.weight', 'fcs.0.bias', 'fcs.2.weight', 'fcs.2.bias', 'fcs.4.weight', 'fcs.4.bias'])
Diff in convs.1.running_mean
Diff in convs.1.running_var
Diff in fcs.2.weight
3
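From what I understand (again, an assumption on my part), these statistics are updated inside forward() whenever the BatchNorm module is in training mode, independently of any gradients. So a possible workaround would be to switch just the BatchNorm layers to eval mode, a sketch:

# Sketch: assuming eval mode stops the running statistics from updating
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # running_mean / running_var are not updated in eval mode

Note that a later call to model.train() would put these layers back into training mode, so this would presumably have to be re-applied after every such call. Is this the right approach, or is there a cleaner way?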