I am using a 2D CNN + Transformer for an action classification task. My idea was to use the CNN to extract spatial features from each frame and then pass them to the Transformer to capture the temporal dynamics, but when training the model I got very poor loss and accuracy: neither training nor validation accuracy ever passed 25%.
This is how I defined my model:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

class CNN2DTransformer(nn.Module):
    def __init__(self, num_action_classes, num_condition_classes, d_model=512, nhead=8, num_layers=6):
        """
        Args:
            num_action_classes (int): Number of action classes.
            num_condition_classes (int): Number of condition classes.
            d_model (int): Transformer embedding dimension.
            nhead (int): Number of heads in the multi-head attention mechanism.
            num_layers (int): Number of Transformer Encoder layers.
        """
        super(CNN2DTransformer, self).__init__()

        # 1. Backbone CNN (ResNet-50)
        self.cnn = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])  # drop the avgpool and fc layers
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive pooling to get a fixed-size feature vector per frame
        self.cnn_feature_dim = 2048  # ResNet-50 outputs 2048-dimensional features

        # 2. Transformer Encoder
        self.embedding = nn.Linear(self.cnn_feature_dim, d_model)  # project CNN features into the Transformer dimension
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=2048, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. Classification heads
        self.action_head = nn.Linear(d_model, num_action_classes)
        self.condition_head = nn.Linear(d_model, num_condition_classes)

    def forward(self, frames):
        # frames: (batch_size, seq_len, c, h, w)
        batch_size, seq_len, c, h, w = frames.size()
        frames = frames.view(batch_size * seq_len, c, h, w)               # merge batch and time so each frame goes through the 2D CNN
        cnn_features = self.cnn(frames)                                   # (batch_size * seq_len, 2048, h', w')
        cnn_features = self.pool(cnn_features).squeeze(-1).squeeze(-1)    # (batch_size * seq_len, 2048)
        cnn_features = cnn_features.view(batch_size, seq_len, -1)         # (batch_size, seq_len, 2048)
        transformer_input = self.embedding(cnn_features)                  # (batch_size, seq_len, d_model)
        transformer_output = self.transformer_encoder(transformer_input)  # (batch_size, seq_len, d_model)
        global_features = transformer_output.mean(dim=1)                  # (batch_size, d_model)
        action_logits = self.action_head(global_features)                 # (batch_size, num_action_classes)
        condition_logits = self.condition_head(global_features)           # (batch_size, num_condition_classes)
        return action_logits, condition_logits
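To show the data flow I expect, here is a quick shape check with dummy input. The 224x224 resolution, the batch size of 2, and the 10 action classes are just placeholder values (only the 3 condition classes, CN/MCI/AD, come from my data):

model = CNN2DTransformer(num_action_classes=10, num_condition_classes=3)  # 10 actions is a placeholder; 3 conditions = CN/MCI/AD
dummy_clip = torch.randn(2, 32, 3, 224, 224)   # (batch_size, seq_len, c, h, w)
action_logits, condition_logits = model(dummy_clip)
print(action_logits.shape)     # torch.Size([2, 10])
print(condition_logits.shape)  # torch.Size([2, 3])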
And this is how I defined my dataset:
import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class ActionDataset(Dataset):
    def __init__(self, csv_file, transform=None, num_frames=32):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.num_frames = num_frames
        # Dictionary {action_id: [(image_path, condition)]}
        self.action_dict = self._group_by_action()
        # Map labels to indices
        self.action_classes = sorted(self.data["action_name"].unique())
        self.condition_classes = sorted(self.data["condition"].unique())
        self.action_to_idx = {c: i for i, c in enumerate(self.action_classes)}
        self.condition_to_idx = {c: i for i, c in enumerate(self.condition_classes)}

    def _group_by_action(self):
        # Group frames into one sequence per (patient_id, action_name) pair
        action_dict = {}
        for _, row in self.data.iterrows():
            action_id = f"{row['patient_id']}_{row['action_name']}"
            if action_id not in action_dict:
                action_dict[action_id] = []
            action_dict[action_id].append((row["image_path"], row["condition"]))
        return action_dict

    def _sample_frames(self, frames):
        # Pick num_frames evenly spaced frames; pad with the last frame if the clip is too short
        if len(frames) >= self.num_frames:
            indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
        else:
            indices = list(range(len(frames))) + [len(frames) - 1] * (self.num_frames - len(frames))
        return [frames[i] for i in indices]

    def __len__(self):
        return len(self.action_dict)

    def __getitem__(self, idx):
        action_id = list(self.action_dict.keys())[idx]
        frames = self.action_dict[action_id]
        sampled_frames = self._sample_frames(frames)
        images = []
        for img_path, _ in sampled_frames:
            image = cv2.imread(img_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if self.transform:
                image = self.transform(image)
            images.append(image)
        images = torch.stack(images)  # (num_frames, c, h, w)
        action_label = self.action_to_idx[action_id.split("_")[1]]     # action name
        condition_label = self.condition_to_idx[sampled_frames[0][1]]  # condition (CN, MCI, AD)
        return images, torch.tensor(action_label), torch.tensor(condition_label)
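For context, this is roughly how the dataset plugs into a DataLoader. The CSV path, batch size, and the exact transform are placeholders, not my real training script; the transform just has to turn each numpy RGB frame into a fixed-size tensor so that torch.stack works:

from torch.utils.data import DataLoader
from torchvision import transforms

# Placeholder transform: cv2 gives a numpy HWC array, so convert to PIL, resize, then to a (c, h, w) tensor
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ActionDataset("annotations.csv", transform=transform, num_frames=32)  # placeholder CSV path
loader = DataLoader(dataset, batch_size=4, shuffle=True)  # placeholder batch size

frames, action_labels, condition_labels = next(iter(loader))
print(frames.shape)  # torch.Size([4, 32, 3, 224, 224]) -> matches the model's (batch_size, seq_len, c, h, w) input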
I used a sampling function to get only a 32-frame sequence from the full set of images (I had to group them by patient_action to get each sequence; a small example of the index selection is shown below). As I understand it, the Transformer takes a sequence of frames as input, but I think this input is not right for the 2D CNN. Does anyone have a suggestion?
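For reference, this is what the evenly spaced sampling in _sample_frames produces for a clip longer than 32 frames (the clip length of 100 is just an example):

import numpy as np

indices = np.linspace(0, 100 - 1, 32, dtype=int)  # 100-frame clip sampled down to 32 frames
print(indices)
# [ 0  3  6  9 12 ... 92 95 99] -> 32 indices spread across the whole clip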