Extracting features doesn't work

I need to extract features while ignoring the classification layer, so I tried the code below, but I think it is wrong:

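import torch
import torch.nn as nn

class Identity(nn.Module):  # same behavior as the built-in nn.Identity
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x  # pass the input through unchanged

PATH = 'model.pth'
model = Model()
# strict=False silently skips keys that don't match; check what was skipped
result = model.load_state_dict(torch.load(PATH), strict=False)
print(result.missing_keys, result.unexpected_keys)
model.head = Identity()  # replace the classifier so forward() returns features
model.eval()  # disable Dropout/DropPath for deterministic feature extraction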
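The `strict=False` argument deserves a check, because it makes `load_state_dict` skip every key that does not match instead of raising an error. If the checkpoint keys carry a prefix (for example `module.` from `nn.DataParallel`; this is only an assumption about how `model.pth` was saved), nothing gets loaded and the model silently keeps its random weights, which would also produce bad losses. A minimal sketch, using a small hypothetical `nn.Sequential` in place of `Model`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for Model(); any nn.Module shows the same behavior.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Simulated checkpoint whose keys carry a "module." prefix, as happens when an
# nn.DataParallel-wrapped model is saved (an assumption about model.pth).
checkpoint = {"module." + k: v.clone() + 1.0 for k, v in model.state_dict().items()}

# strict=False raises nothing even though no key matches:
# the weights silently keep their random initialization.
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# Strip the prefix, then load with strict=True so any remaining mismatch raises.
prefix = "module."
cleaned = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in checkpoint.items()}
model.load_state_dict(cleaned, strict=True)
```

If `missing_keys` lists most of the model's parameters, the checkpoint was not actually loaded.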

The end of the printed model structure is:

       (1): TransformerBlock(
          dim=768, input_resolution=(7, 7), num_heads=24, window_size=7, shift_size=0, mlp_ratio=4.0
          (pool_layers): ModuleList()
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=768, window_size=(7, 7), num_heads=24
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (avgpool): AdaptiveAvgPool1d(output_size=1)
  (head): Identity()

I couldn't see any obvious sign that this code is wrong. What made you conclude that it is wrong?

I get bad loss values in both training and validation.
The model's code is:

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)  # B C: pooled feature vector
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)  # with head = Identity, the features pass through unchanged
        return x

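I think this code does what you intend (i.e., it extracts features without the classification layer): with head replaced by Identity, forward() returns the flattened, average-pooled output of forward_features, which for this model is a 768-dimensional vector per sample.
The bad loss values probably have a different cause than this code.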
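As a quick check of that reasoning, here is a minimal sketch with a hypothetical, heavily simplified `TinyBackbone` (plain linear layers stand in for the real patch embedding and Transformer blocks) that keeps the same norm, avgpool, flatten tail. With `head` set to the built-in `nn.Identity()`, calling the model returns exactly what `forward_features` returns:

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical, simplified model with the same tail as the posted
    forward_features: tokens -> norm -> avgpool over tokens -> flatten -> head."""

    def __init__(self, in_dim=16, dim=8, num_classes=5):
        super().__init__()
        self.patch_embed = nn.Linear(in_dim, dim)  # stand-in for the real patch embedding
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])  # stand-in for TransformerBlocks
        self.norm = nn.LayerNorm(dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(dim, num_classes)

    def forward_features(self, x):           # x: B L in_dim
        x = self.patch_embed(x)              # B L C
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)                     # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        return torch.flatten(x, 1)           # B C

    def forward(self, x):
        return self.head(self.forward_features(x))

torch.manual_seed(0)
model = TinyBackbone()
model.head = nn.Identity()  # built-in equivalent of the custom Identity class
model.eval()

x = torch.randn(3, 49, 16)  # 3 samples, 49 tokens (7x7), 16 input features each
with torch.no_grad():
    feats = model(x)
    direct = model.forward_features(x)

print(feats.shape)                 # torch.Size([3, 8])
print(torch.equal(feats, direct))  # True
```

Calling `model.forward_features(x)` directly gives the same result without touching `head` at all.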

I think I'm on the wrong track: I have run into this problem before, and its cause was the extracted features themselves.
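If the features themselves are suspect, a few quick checks on the extracted matrix can help narrow it down: non-finite values, zero-variance dimensions, and an unusual overall scale. The helpers below (`check_features` and `standardize` are hypothetical names, not part of any library) sketch this, and also z-score the features using training-split statistics only. Whether standardization is appropriate depends on the downstream model, which the thread does not describe:

```python
import torch

def check_features(feats: torch.Tensor) -> dict:
    """Hypothetical helper: basic sanity statistics for an (N, C) feature matrix."""
    return {
        "non_finite": int((~torch.isfinite(feats)).sum()),    # NaN or inf entries
        "constant_dims": int((feats.std(dim=0) == 0).sum()),  # dimensions with zero variance
        "mean_abs": float(feats.abs().mean()),                 # overall scale of the features
    }

def standardize(train_feats: torch.Tensor, other_feats: torch.Tensor, eps: float = 1e-6):
    """Z-score both splits with mean/std computed on the training split only."""
    mean = train_feats.mean(dim=0, keepdim=True)
    std = train_feats.std(dim=0, keepdim=True)
    return (train_feats - mean) / (std + eps), (other_feats - mean) / (std + eps)

# Synthetic stand-ins for extracted 768-dim features (not real data).
torch.manual_seed(0)
train_feats = torch.randn(100, 768) * 50 + 10
val_feats = torch.randn(20, 768) * 50 + 10

print(check_features(train_feats))
train_z, val_z = standardize(train_feats, val_feats)
```

Computing the statistics on the training split only keeps validation data from leaking into preprocessing.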