I think your previous (deleted) approach is ok. Can you try this?
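    import torch.nn as nn

    class Identity(nn.Module):
        """Pass-through layer that returns its input unchanged."""
        def __init__(self):
            super(Identity, self).__init__()

        def forward(self, x):
            return x

    # Swap out the head so the model returns the features that feed it.
    model.head = Identity()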
See this answer.
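As a quick sanity check, here is a minimal sketch of the same idea on a toy model. `ToyModel`, its `backbone` layer, and the layer sizes are made up for illustration; only the `head` attribute name comes from your code. It uses the built-in `nn.Identity`, which should behave the same as the hand-written class above on PyTorch versions that ship it.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for the real model: a feature extractor plus a head.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)  # assumed feature extractor
        self.head = nn.Linear(16, 4)      # assumed classification head

    def forward(self, x):
        return self.head(self.backbone(x))


model = ToyModel()
model.head = nn.Identity()  # built-in pass-through layer

x = torch.randn(2, 8)
features = model(x)
print(features.shape)  # torch.Size([2, 16])
```

With the head replaced, the model returns the 16-dimensional backbone features instead of the 4 head outputs, which is what you want if the goal is to extract features.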