Too many values to unpack from tensor

I think your previous (deleted) approach is ok. Can you try this?

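import torch.nn as nn

class Identity(nn.Module):
    """Pass-through module: returns its input unchanged."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Return the input as-is; a bare `return` would yield None
        return x

# Swap the head for the pass-through so the model outputs the features that fed it
model.head = Identity()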
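
To check that swapping the head does what you expect, here is a minimal sketch. `ToyModel`, its layer sizes, and the `backbone` attribute are made up for illustration and are not your model; the only thing assumed about yours is that it exposes its classifier as `head`. It also uses `nn.Identity`, which recent PyTorch releases ship built in and which should behave the same as the hand-written class.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical stand-in for a model that exposes a `head` attribute."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)   # produces 16-dim features
        self.head = nn.Linear(16, 3)       # maps features to 3 class scores

    def forward(self, x):
        return self.head(self.backbone(x))

model = ToyModel()
x = torch.randn(4, 8)                      # batch of 4 inputs

logits_shape = tuple(model(x).shape)       # head still attached

model.head = nn.Identity()                 # built-in pass-through module
features_shape = tuple(model(x).shape)     # output is now the backbone features

print(logits_shape, features_shape)
```

If the second shape matches the feature dimension rather than the number of classes, the head was replaced correctly and the forward pass returns a single tensor.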

See this answer.