How to remove a prediction head from the model?

I am working on a ViT (Vision Transformer) related project, and some of the low-level definitions are buried deep inside the timm library, which I cannot change. The low-level library definition involves a linear classification prediction head, which is not part of my network.

Everything was fine until I switched to a DDP (DistributedDataParallel) implementation. PyTorch complained about some parameters that didn't contribute to the loss and instructed me to use "find_unused_parameters=True". In fact, this is a common scenario, and training worked again once I added "find_unused_parameters=True" to the training routine. However, I am only allowed to change the model definition in our code base; I cannot modify anything related to training …

So I guess the only thing I can do right now is to "remove" the linear head from the model.
Although I cannot dig into the low-level definition of the ViT, I can get this tensor as an output like this:

encoder_output, linear_head_output = ViT(input)

Is it possible to remove this linear prediction head based on this linear_head_output tensor?

The easiest fix would be to assign an nn.Identity() module as the prediction head.
This lets the model call into the prediction head as desired, but it just returns its input.
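(A quick sketch of what nn.Identity does: it simply passes its input through unchanged and has no parameters of its own.)

    import torch

    identity = torch.nn.Identity()
    x = torch.randn(2, 3)
    # Identity's forward just returns its input, so it adds no parameters
    # and nothing can show up as "unused" for DDP.
    assert identity(x) is x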

Best regards

Thomas


@tom Thanks for your suggestion. My current workaround looks like this:
encoder_output = encoder_output + 0 * linear_head_output
It doesn't affect the calculation of the loss based on encoder_output, while making linear_head_output "used". But it is obviously not an elegant way to fix this, and I am not sure whether it also affects runtime.
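In context, it sits in a wrapper's forward roughly like this (a simplified sketch; vit stands for the low-level model that returns both tensors, and the .sum() is just there so the shapes don't have to match):

    import torch.nn as nn

    class EncoderOnly(nn.Module):
        def __init__(self, vit):
            super().__init__()
            self.vit = vit  # low-level model returning (encoder_output, linear_head_output)

        def forward(self, x):
            encoder_output, linear_head_output = self.vit(x)
            # Multiplying by zero keeps the head in the autograd graph
            # (so DDP sees it as "used") without changing the loss value.
            return encoder_output + 0 * linear_head_output.sum()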

Could you please give an example of how to use nn.Identity() in this case? I have a feeling that it would be a better fix than mine, but I just don't know how to implement it.

Basically, it would just be something like the following (I don't know what exactly the submodule that is causing you trouble is, so I'm calling it linear_head here):

        model.linear_head = torch.nn.Identity()
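
As a more concrete sketch with a plain timm ViT (assuming the classifier attribute is called head, which is what timm's VisionTransformer uses; check your model's named_modules() to find the actual name in your code base):

    import timm
    import torch

    model = timm.create_model("vit_base_patch16_224", pretrained=False)

    # Swap the linear classification head for an Identity. The head's
    # parameters are no longer registered on the model, so DDP has
    # nothing "unused" to complain about.
    model.head = torch.nn.Identity()

    x = torch.randn(1, 3, 224, 224)
    features = model(x)  # pooled encoder features instead of class logits

If I remember correctly, timm models also offer model.reset_classifier(0) for the same purpose.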

I have an example of this (albeit with my own BERT implementation) in the tests of my (educational + WIP) Toroidal transformer library.

Best regards

Thomas
