My goal is to use a pretrained 2D CNN for feature extraction, followed by a transformer to capture temporal information.
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchvision.models import ResNet18_Weights
class ResNetTransformer(nn.Module):
    """Per-frame ResNet-18 features followed by a Transformer encoder over time.

    Every frame of a clip is embedded by a pretrained ResNet-18
    (``ResNet18_Weights.DEFAULT``) whose final fully connected layer has been
    removed. The sequence of frame embeddings goes through a Transformer
    encoder, is averaged over the valid (non-padded) time steps, and is fed to
    two independent linear heads: one for the action label and one for the
    condition label.

    NOTE(review): no positional encoding is added to the frame embeddings and
    the pooled output is a mean over time, so the prediction does not depend
    on the order of the frames. Add a positional encoding before the encoder
    if frame order should matter.

    Args:
        input_shape: Not used by the model; kept so existing callers keep working.
        num_action_classes: Output size of the action head.
        num_condition_classes: Output size of the condition head.
        num_heads: Number of attention heads in every encoder layer.
        ff_dim: Hidden size of the feed-forward sub-layer of every encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_shape, num_action_classes, num_condition_classes, num_heads, ff_dim, num_layers):
        super().__init__()
        backbone = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        # Input width of the removed fc layer, i.e. the size of one frame embedding.
        self.feature_dim = backbone.fc.in_features
        # Drop the fc layer: only the globally pooled features are needed.
        self.resnet18 = nn.Sequential(*list(backbone.children())[:-1])
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=self.feature_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        # Two labels are predicted per clip, hence two separate heads.
        self.action_head = nn.Linear(self.feature_dim, num_action_classes)
        self.condition_head = nn.Linear(self.feature_dim, num_condition_classes)

    def forward(self, x, mask=None):
        """Classify a batch of (possibly padded) frame sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, C, H, W)``.
            mask: Optional tensor of shape ``(batch, seq_len)`` that is True
                at real frames and False at padded frames. It can be built
                from per-sequence lengths with
                ``torch.arange(seq_len)[None, :] < lengths[:, None]``.
                ``None`` means that no frame is padding.

        Returns:
            Tuple ``(action_pred, condition_pred)`` with the raw outputs of
            the two linear heads, one row per clip.
        """
        batch_size, seq_len, c, h, w = x.shape
        # reshape instead of view so that non-contiguous inputs are accepted.
        frames = x.reshape(batch_size * seq_len, c, h, w)
        features = self.resnet18(frames).reshape(batch_size, seq_len, -1)

        if mask is None:
            temporal_features = self.transformer(features)
            pooled_features = temporal_features.mean(dim=1)
        else:
            valid = mask.to(device=features.device, dtype=torch.bool)
            # The encoder expects True at positions it must ignore.
            temporal_features = self.transformer(features, src_key_padding_mask=~valid)
            # Average over real frames only; a plain mean would mix in the
            # encoder outputs at padded positions.
            keep = valid.unsqueeze(-1)
            summed = temporal_features.masked_fill(~keep, 0.0).sum(dim=1)
            counts = keep.sum(dim=1).clamp(min=1).to(summed.dtype)
            pooled_features = summed / counts

        action_pred = self.action_head(pooled_features)
        condition_pred = self.condition_head(pooled_features)
        return action_pred, condition_pred
Is this the right way to implement my multi-label classification?
As for the input: I have multiple subjects, and each subject has an action label (subjects perform multiple actions) and a condition label. I use sequences of action frames as input (one action is treated as one sequence) together with a custom collate_fn:
def collate_fn(batch):
    """Zero-pad variable-length frame sequences into one batch tensor.

    Args:
        batch: Iterable of ``(frames, action_label, condition_label)``
            samples, where ``frames`` is indexed as ``(T, ...)`` and ``T``
            may differ between samples.

    Returns:
        Tuple ``(padded_frames, action_labels, condition_labels, lengths)``.
        ``padded_frames`` has shape ``(batch_size, max_T, ...)`` and is zero
        after the end of each sequence; ``lengths`` holds the original ``T``
        of every sample.
    """
    sequences, actions, conditions = zip(*batch)
    seq_lengths = [len(sequence) for sequence in sequences]

    # Trailing (per-frame) dimensions are taken from the first sample.
    frame_shape = tuple(sequences[0].shape[1:])
    padded = torch.zeros((len(sequences), max(seq_lengths)) + frame_shape)
    for row, (sequence, n_frames) in enumerate(zip(sequences, seq_lengths)):
        padded[row, :n_frames] = sequence

    return (
        padded,
        torch.tensor(actions),
        torch.tensor(conditions),
        torch.tensor(seq_lengths),
    )
and then call it in:
# Options shared by all three loaders; only the training split is shuffled.
_loader_options = {"batch_size": 2, "collate_fn": collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)
Is there anything else to consider?