Partial transfer learning EfficientNet

Hi, can someone point me to an example of partial transfer learning with EfficientNet? For instance, unfreezing the last two blocks of the network?

Thanks in advance!

You can freeze all parameters of the model first via:

for param in model.parameters():
    param.requires_grad = False

and later unfreeze the desired blocks: print the model (via print(model)) and use the corresponding module names to unfreeze their parameters. E.g., assuming the last two “blocks” are named block1 and block2, this should work:

for param in model.block1.parameters():
    param.requires_grad = True
for param in model.block2.parameters():
    param.requires_grad = True
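A minimal runnable sketch of the same freeze-then-unfreeze idea, using nn.Module.requires_grad_ instead of explicit loops. The model here is a hypothetical toy nn.Sequential whose submodules are named block1 and block2 to match the example names above; with a real network you would substitute the names reported by print(model).

```python
import torch.nn as nn

# Hypothetical toy model; "block1" and "block2" stand in for the real block
# names you would read off print(model).
model = nn.Sequential()
model.add_module("stem", nn.Conv2d(3, 8, 3))
model.add_module("block1", nn.Conv2d(8, 8, 3))
model.add_module("block2", nn.Conv2d(8, 8, 3))

# Freeze every parameter (same effect as setting requires_grad = False in a loop).
model.requires_grad_(False)

# Unfreeze only the chosen submodules, looked up by name.
for name in ["block1", "block2"]:
    getattr(model, name).requires_grad_(True)

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['block1.weight', 'block1.bias', 'block2.weight', 'block2.bias']
```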

Indeed, thanks for the suggestion. A bit granular, but it works!

import timm

def get_model(model_name):
    model = timm.create_model(model_name, pretrained=True)
    return model

model = get_model('efficientnet_b0')
# Inspect the top-level module names
for name, _ in model.named_children():
    print(name)

# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

blocks_to_retrain = 1
count = 0
for name, block in model.blocks.named_children():
    if count >= (len(model.blocks) - blocks_to_retrain):
        print(f'Unfreeze block {name}')
        for pname, params in block.named_parameters():
            if 'bn' not in pname:
                params.requires_grad = True
    else:
        print(f'Keep block {name} frozen')
    count += 1

# Unfreeze the head modules
for name, child in model.named_children():
    if name in ['conv_head', 'act2', 'global_pool', 'classifier']:
        print(f'Unfreeze block {name}')
        for params in child.parameters():
            params.requires_grad = True

Here is an example of partial transfer learning.
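The same logic can also be wrapped in a reusable function. This is a sketch, not timm's API: ToyEfficientNet, ToyBlock and set_partial_trainable are hypothetical names, and the toy model only roughly mimics the top-level attribute names of timm's EfficientNet (conv_stem, bn1, blocks, conv_head, bn2, global_pool, classifier) so that it runs without downloading weights. It also hands only the trainable parameters to the optimizer and runs one training step.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block: a conv followed by a BatchNorm named 'bn'."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(ch)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))


class ToyEfficientNet(nn.Module):
    """Hypothetical stand-in using timm-EfficientNet-like attribute names."""

    def __init__(self, num_blocks=7, ch=8, head_ch=16, num_classes=10):
        super().__init__()
        self.conv_stem = nn.Conv2d(3, ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.blocks = nn.Sequential(*[ToyBlock(ch) for _ in range(num_blocks)])
        self.conv_head = nn.Conv2d(ch, head_ch, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(head_ch)
        self.global_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.classifier = nn.Linear(head_ch, num_classes)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv_stem(x)))
        x = self.blocks(x)
        x = torch.relu(self.bn2(self.conv_head(x)))
        return self.classifier(self.global_pool(x))


def set_partial_trainable(model, blocks_to_retrain=1,
                          head=("conv_head", "global_pool", "classifier")):
    """Freeze everything, then unfreeze the last blocks (except BN) and the head."""
    model.requires_grad_(False)
    first_trainable = len(model.blocks) - blocks_to_retrain
    for idx, block in enumerate(model.blocks):
        if idx >= first_trainable:
            for pname, param in block.named_parameters():
                if "bn" not in pname:  # keep BatchNorm weight/bias frozen
                    param.requires_grad = True
    for name, child in model.named_children():
        if name in head:
            child.requires_grad_(True)


torch.manual_seed(0)
model = ToyEfficientNet()
set_partial_trainable(model, blocks_to_retrain=1)

# Give the optimizer only the parameters that will actually be trained.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01)

# One training step on random data: only unfrozen weights receive gradients.
x = torch.randn(2, 3, 32, 32)
y = torch.tensor([0, 1])
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

trainable = sorted(n for n, p in model.named_parameters() if p.requires_grad)
print(trainable)
```

Note that setting requires_grad = False on BatchNorm parameters does not stop their running mean/variance from updating while the model is in train() mode; if those statistics should stay fixed as well, the BatchNorm modules would also need to be put in eval() mode.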

I use EfficientNet as a feature extractor and only train the last layer. Training works; however, at each epoch I get this warning:

UserWarning: An output with one or more elements was resized since it had shape [1204224], which does not match the required output shape [8, 3, 224, 224].This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at /opt/conda/conda-bld/pytorch_1641456362397/work/aten/src/ATen/native/Resize.cpp:24.)
return torch.stack(batch, 0, out=out)

Does anyone know what it is?