Using TransformerEncoderLayer for image classification

My goal is to use a pretrained 2D CNN for feature extraction, followed by a Transformer encoder to model the temporal information:

import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchvision.models import ResNet18_Weights

class ResNetTransformer(nn.Module):
    def __init__(self, input_shape, num_action_classes, num_condition_classes, num_heads, ff_dim, num_layers):
        super(ResNetTransformer, self).__init__()
        self.resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        # fc layer removed, as I only need the 512-d pooled features
        self.resnet18 = nn.Sequential(*list(self.resnet18.children())[:-1])
        self.feature_dim = 512
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(d_model=self.feature_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True),
            num_layers=num_layers
        )
        
        # as I have two labels to predict, I replaced the classifier with two prediction heads
        self.action_head = nn.Linear(self.feature_dim, num_action_classes)
        self.condition_head = nn.Linear(self.feature_dim, num_condition_classes)

    def forward(self, x, mask):
        # x: (batch, seq_len, C, H, W); mask: (batch, seq_len), True for real frames
        batch_size, seq_len, c, h, w = x.shape
        x = x.view(batch_size * seq_len, c, h, w)
        features = self.resnet18(x)                        # (batch * seq_len, 512, 1, 1)
        features = features.view(batch_size, seq_len, -1)  # (batch, seq_len, 512)

        # the sequences have different lengths, so they are zero-padded in a custom
        # collate_fn; the mask lets the transformer ignore the padded positions
        # (src_key_padding_mask expects True at positions to ignore, hence ~mask)
        temporal_features = self.transformer(features, src_key_padding_mask=~mask)
 
        # masked mean over time so the padded positions do not affect the pooled features
        mask_f = mask.unsqueeze(-1).float()
        pooled_features = (temporal_features * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1)
        
        action_pred = self.action_head(pooled_features)
        condition_pred = self.condition_head(pooled_features)
        
        return action_pred, condition_pred
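
For reference, this is roughly how I build the boolean mask (True = real frame) and call the model; the hyperparameters and shapes below are just placeholder values for a dummy forward pass:

import torch

# placeholder hyperparameters, only to illustrate the call
model = ResNetTransformer(
    input_shape=(3, 224, 224),
    num_action_classes=5,
    num_condition_classes=3,
    num_heads=8,
    ff_dim=2048,
    num_layers=2,
)

batch_size, max_length = 2, 16
frames = torch.randn(batch_size, max_length, 3, 224, 224)   # padded batch of frame sequences
lengths = torch.tensor([16, 10])                             # real length of each sequence
mask = torch.arange(max_length).unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_length), True = valid frame

action_pred, condition_pred = model(frames, mask)
print(action_pred.shape, condition_pred.shape)  # torch.Size([2, 5]) torch.Size([2, 3])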

Is this the right choice of implementation for my multi-output classification (predicting both an action label and a condition label)?
As for the input: I have multiple subjects, and each subject has action labels (they perform multiple actions) and a condition label. I use sequences of action frames as input (one action = one sequence) and a custom collate_fn:

def collate_fn(batch):
    # batch is a list of (frames, action_label, condition_label) tuples,
    # where frames has shape (num_frames, C, H, W) and num_frames varies
    frames, action_labels, condition_labels = zip(*batch)
    lengths = [len(seq) for seq in frames]
    max_length = max(lengths)

    # zero-pad every sequence to the longest one in the batch
    padded_frames = torch.zeros(len(frames), max_length, *frames[0].shape[1:])  # (batch_size, max_length, C, H, W)
    for i, seq in enumerate(frames):
        padded_frames[i, :lengths[i]] = seq
   
    action_labels = torch.tensor(action_labels)
    condition_labels = torch.tensor(condition_labels)
    lengths = torch.tensor(lengths)
    
    return padded_frames, action_labels, condition_labels, lengths
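
For context, each item of my dataset is basically a tuple (frames, action_label, condition_label), with frames of shape (num_frames, C, H, W); a simplified sketch of such a dataset (the stored tensors here stand in for the real frame loading):

import torch
from torch.utils.data import Dataset

class ActionSequenceDataset(Dataset):
    # simplified sketch: one item = one action sequence of one subject
    def __init__(self, sequences, action_labels, condition_labels):
        self.sequences = sequences                 # list of (num_frames_i, C, H, W) float tensors
        self.action_labels = action_labels         # list of ints
        self.condition_labels = condition_labels   # list of ints

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # returned in exactly the order collate_fn unpacks them
        return self.sequences[idx], self.action_labels[idx], self.condition_labels[idx]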

and then use it in the DataLoaders:

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
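
and here is a simplified sketch of the training step I had in mind (the mask is rebuilt from lengths, and each head gets its own cross-entropy loss):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for frames, action_labels, condition_labels, lengths in train_loader:
    # rebuild the boolean mask (True = real frame) from the lengths returned by collate_fn
    mask = torch.arange(frames.size(1)).unsqueeze(0) < lengths.unsqueeze(1)

    action_pred, condition_pred = model(frames, mask)
    loss = criterion(action_pred, action_labels) + criterion(condition_pred, condition_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()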

Anything else I should consider?