I am trying to freeze all layers except a few (in the example below, fc_cat and fc_kin_lin are two linear layers) while keeping the dropout layers on and the batch norm layers turned off. Is this the right way to do it? I know there are separate posts covering each of these, but not all of them together.
```python
def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

model.train()
if args.freeze_bn:
    model.apply(set_bn_eval)

state = model.state_dict()
# pdb.set_trace()
for name, param in state.items():
    # note: find() returns -1 when the substring is absent, not 0
    if name.find('fc_cat') != -1 or name.find('fc_kin_lin') != -1:
        print('\n\n', name, 'layer parameters will be trained\n\n')
    else:
        print(name, 'layer parameters are being frozen')
        if isinstance(param, Parameter):
            param.requires_grad = False

pdb.set_trace()
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
```
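One pitfall worth checking here, sketched on a toy `nn.Linear` model (not the real network): `model.state_dict()` returns detached plain tensors, not the live `nn.Parameter` objects, so toggling `requires_grad` on them has no effect on the model itself.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model.
model = nn.Linear(4, 2)

# state_dict() hands back detached tensors, not live Parameters,
# so this loop does not freeze anything in the model.
for name, t in model.state_dict().items():
    t.requires_grad = False

# The actual parameters are untouched and still trainable.
print(all(p.requires_grad for p in model.parameters()))  # True
```

Note also that the detached tensors are not `Parameter` instances, so an `isinstance(param, Parameter)` check on `state_dict()` entries silently skips every one of them.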
Replying to my own question: I realized that I could use named_parameters() instead of state_dict().
```python
parameter_dict = dict(model.named_parameters())
for name, param in parameter_dict.items():
    if name.find('fc_cat') != -1 or name.find('fc_kin_lin') != -1:
        print('\n\n', name, 'layer parameters will be trained\n\n')
    else:
        if isinstance(param, Parameter):
            print(name, 'layer parameters are being frozen')
            param.requires_grad = False
```
That does the job for parameter freezing. BUT I am still having a problem with turning the BN layers off: using the above solution gives me nan.
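The named_parameters() approach above can be verified on a tiny model (module names here are hypothetical stand-ins for the real fc_cat-style layers):

```python
import torch
import torch.nn as nn

# Toy model: one submodule to freeze, one named like a layer to keep trainable.
model = nn.Sequential()
model.add_module('backbone', nn.Linear(8, 8))
model.add_module('fc_cat', nn.Linear(8, 3))

# Freeze everything except fc_cat, matching on the parameter name.
for name, param in model.named_parameters():
    param.requires_grad = 'fc_cat' in name

frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(frozen)     # ['backbone.weight', 'backbone.bias']
print(trainable)  # ['fc_cat.weight', 'fc_cat.bias']
```

It can also help to hand the optimizer only the trainable parameters, e.g. `torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=...)`, so frozen weights are not touched by weight decay.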
You can use `model.apply(fn)` here for the batch norm layers:

```python
def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
```
Then call this function via `.apply` before starting the training:

```python
model.train()
model.apply(set_bn_eval)
```
Now you can train your model the same way as before; the BN layers won't be trained. After an intermediate validation pass where you might have called model.eval(), switch back to training via model.train() and follow it with model.apply(set_bn_eval) again, since model.train() puts the BN layers back into training mode.
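Putting the pieces together, a minimal sketch of that per-epoch pattern (toy model, random data, and a dummy loss, purely for illustration):

```python
import torch
import torch.nn as nn

def set_bn_eval(m):
    # Force every BatchNorm layer into eval mode so its running stats freeze.
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 2))
opt = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 2)

for epoch in range(2):
    model.train()
    model.apply(set_bn_eval)   # must follow every model.train() call
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()

    model.eval()               # intermediate validation
    with torch.no_grad():
        model(x)

# The BN layer's running stats were never updated during training.
print(torch.equal(model[1].running_mean, torch.zeros(4)))  # True
```

Gradients still flow through the BN layer's affine weight and bias in eval mode, so they are updated by the optimizer unless you also set their `requires_grad` to False.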