Hybrid CNN + ViT approach

Hello, I would like to create a hybrid CNN + ViT image classification model.

What I want to do:

  1. Extract features with a CNN, e.g. a pretrained EfficientNet-B3
  2. Pass the features to a (custom) ViT as the input of its patch embedding
  3. Perform classification

My question is: is this approach appropriate?

Code:
Ad 1. Take the output of the last Conv2dNormActivation block of EfficientNet-B3:

│    └─Conv2dNormActivation (8)          [32, 384, 7, 7]      [32, 1536, 7, 7]
│    │    └─Conv2d (0)                   [32, 384, 7, 7]      [32, 1536, 7, 7]
│    │    └─BatchNorm2d (1)              [32, 1536, 7, 7]     [32, 1536, 7, 7]
│    │    └─SiLU (2)                     [32, 1536, 7, 7]     [32, 1536, 7, 7]

so I will have a feature map of shape [32, 1536, 7, 7].
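For reference, a minimal shape check for step 1 (just a sketch; it assumes a 224×224 input, which is what produces the 7×7 spatial size shown above, since EfficientNet downsamples by a factor of 32):

import torch
import torchvision

weights = torchvision.models.EfficientNet_B3_Weights.DEFAULT
backbone = torchvision.models.efficientnet_b3(weights=weights)
backbone.eval()

with torch.no_grad():
    dummy = torch.randn(32, 3, 224, 224)   # batch of 32 RGB images, 224x224
    feats = backbone.features(dummy)       # convolutional feature extractor only
    print(feats.shape)                     # torch.Size([32, 1536, 7, 7])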

then
Ad 2. Pass them to the custom ViT:

patch_embedding_layer = PatchEmbedding(in_channels=1536,
                                        patch_size=1,
                                        embedding_dim=768)

from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        # A Conv2d with kernel_size == stride == patch_size splits the input
        # into non-overlapping patches and projects each one to embedding_dim.
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, \
            f"Input size must be divisible by patch size, got {image_resolution} and {self.patch_size}"
        x_patched = self.patcher(x)              # [B, embedding_dim, H/P, W/P]
        x_flattened = self.flatten(x_patched)    # [B, embedding_dim, N]
        return x_flattened.permute(0, 2, 1)      # [B, N, embedding_dim]
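With patch_size=1, every spatial position of the 7×7 feature map becomes its own token, so the embedding yields 49 tokens of dimension 768. A quick shape check of the layer defined above (a sketch using a random tensor in place of the real CNN features):

import torch

patch_embedding_layer = PatchEmbedding(in_channels=1536,
                                       patch_size=1,
                                       embedding_dim=768)

feats = torch.randn(32, 1536, 7, 7)   # stand-in for the EfficientNet-B3 features
tokens = patch_embedding_layer(feats)
print(tokens.shape)                   # torch.Size([32, 49, 768])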

Create a new model based on the above:

import torchvision
from torch import nn

class EffC(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()

        weights = torchvision.models.EfficientNet_B3_Weights.DEFAULT
        backbone = torchvision.models.efficientnet_b3(weights=weights)
        # Keep only the convolutional feature extractor and freeze it;
        # the EfficientNet classifier head is not used at all.
        self.features = backbone.features
        for param in self.features.parameters():
            param.requires_grad = False

        # Custom ViT; it must be configured for the CNN output
        # (in_channels=1536, patch_size=1, 7x7 spatial size).
        self.vit = ViT(num_classes=num_classes)

    def forward(self, x):
        x = self.features(x)   # [B, 1536, 7, 7]
        x = self.vit(x)        # [B, num_classes]
        return x
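Finally, an end-to-end shape check of the combined model (again just a sketch; it assumes the custom ViT is built for the CNN output, i.e. in_channels=1536, patch_size=1 and an effective image size of 7, so its position embedding covers the 49 patch tokens plus the class token, and it reuses the class_names and device variables from my setup):

model = EffC(num_classes=len(class_names)).to(device)
model.eval()

with torch.no_grad():
    dummy = torch.randn(8, 3, 224, 224).to(device)
    logits = model(dummy)
    print(logits.shape)   # expected: torch.Size([8, len(class_names)])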