Why does training in Google Colab take forever?

I want to train a model that fuses a pretrained 1D CNN (loaded from a checkpoint) with a pretrained PyTorchVideo network (x3d_s).

At the moment, the largest batch size I can use is 8; anything larger gives a "CUDA out of memory" error. As a result, one training epoch takes 5 minutes.
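
Since 8 is the memory limit, one common workaround is gradient accumulation: run several micro-batches of 8, let their gradients add up in `.grad`, and only then call `optimizer.step()`. Below is a minimal sketch. `train_with_accumulation` is a hypothetical name, and a toy `nn.Linear` stands in for `FusedModel`, which takes two inputs and would be called as `model(x1, x2)`. One caveat: BatchNorm layers (ModelA uses `BatchNorm1d`) still compute their statistics over 8 samples. So accumulation only reproduces the larger batch's averaged gradient, not its normalization.

```python
import torch
from torch import nn


def train_with_accumulation(model, batches, optimizer, criterion, accum_steps):
    """One pass over `batches`, updating weights every `accum_steps` micro-batches.

    Returns the number of optimizer steps taken.
    """
    model.train()
    optimizer.zero_grad()
    steps, pending = 0, 0
    for inputs, targets in batches:
        # Scale each loss so the summed gradients equal the mean over the
        # effective batch (exact when all micro-batches have the same size)
        loss = criterion(model(inputs), targets) / accum_steps
        loss.backward()  # gradients add up in each parameter's .grad
        pending += 1
        if pending == accum_steps:
            optimizer.step()
            optimizer.zero_grad()
            steps, pending = steps + 1, 0
    if pending:  # apply a leftover, smaller group at the end of the epoch
        optimizer.step()
        optimizer.zero_grad()
        steps += 1
    return steps


# Toy usage: 6 micro-batches of 8 with accum_steps=4 gives one full group plus a leftover group of 2
torch.manual_seed(0)
model = nn.Linear(10, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 10), torch.randint(0, 8, (8,))) for _ in range(6)]
steps = train_with_accumulation(model, batches, optimizer, nn.CrossEntropyLoss(), accum_steps=4)
print(steps)  # 2
```

With `accum_steps=4`, each update uses an effective batch of 32 while peak activation memory stays at the batch-of-8 level.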

The number of trainable parameters is 21 373 968. Is that too many? What could I do to speed up training? My dataset contains 700 videos, each with 45 frames (256x256), and 8 classes.
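
A rough size estimate helps put these numbers in perspective. The sketch below assumes float32 tensors, 3 colour channels, and all 45 frames at 256x256 going into x3d_s. It counts only the input; activations inside a 3D CNN usually take far more memory, so treat it as a lower bound. It also converts "5 minutes per epoch" into time per optimizer step. The helper names are hypothetical.

```python
def clip_bytes(frames, height, width, channels=3, bytes_per_value=4):
    """Bytes needed for one video clip tensor (float32 by default)."""
    return frames * channels * height * width * bytes_per_value


def steps_per_epoch(num_samples, batch_size):
    """Optimizer steps per epoch, counting a final partial batch."""
    return -(-num_samples // batch_size)  # ceiling division


MIB = 2 ** 20
per_clip = clip_bytes(frames=45, height=256, width=256)
print(per_clip / MIB)                  # 33.75 MiB per clip
print(8 * per_clip / MIB)              # 270.0 MiB per batch of 8, before any activations
print(steps_per_epoch(700, 8))         # 88
print(300 / steps_per_epoch(700, 8))   # seconds per step at 5 min/epoch, roughly 3.4
```

Roughly 3.4 s per step of 8 clips is the figure to compare against after any change to input size, frame count, or precision.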

Is there perhaps another mistake in my code that makes training so slow?

Here are some parts of the code:

Model A

```python
import torch
from torch import nn


class ModelA(nn.Module):
    def __init__(self):
        super().__init__()

        # Block 1: 51 input channels -> 64
        self.conv1 = nn.Conv1d(in_channels=51, out_channels=64, kernel_size=3, padding=1)
        self.batch1 = nn.BatchNorm1d(64)

        # in_feat, out_feat, stride and dilation are not defined in this excerpt;
        # the BatchNorm sizes suggest out_feat is 128 for block 2 and 256 for block 3
        self.dilated_conv2 = nn.Conv1d(in_feat, out_feat, kernel_size=3, stride=stride, dilation=dilation, padding=3)
        self.conv_transform2 = nn.Conv1d(out_feat, out_feat, kernel_size=3, padding=1)
        self.batch2 = nn.BatchNorm1d(128)

        self.dilated_conv3 = nn.Conv1d(in_feat, out_feat, kernel_size=3, stride=stride, dilation=dilation, padding=3)
        self.conv_transform3 = nn.Conv1d(out_feat, out_feat, kernel_size=3, padding=1)
        self.batch3 = nn.BatchNorm1d(256)
        # forward() is not shown
```
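
Because ModelA's time axis has to line up with x3d_s's time axis before the two are concatenated, it may help to check what `padding=3` does to sequence length. The `nn.Conv1d` documentation gives the output length as floor((L_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride) + 1. The hypothetical helper below evaluates that formula. It assumes an input of 45 time steps (one per frame), which the post does not state explicitly. Under that assumption, `padding=3` with `kernel_size=3` keeps the length unchanged only when `dilation=3` and `stride=1`.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of torch.nn.Conv1d, following the formula in its docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# 45 time steps per sample is an assumption about ModelA's input
for dilation in (1, 2, 3):
    print(dilation, conv1d_out_len(45, kernel_size=3, padding=3, dilation=dilation))
# 1 49
# 2 47
# 3 45
```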

Model B

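```python
class ModelB(nn.Module):
    def __init__(self):
        super().__init__()

        # Pretrained X3D-S from PyTorchVideo via torch.hub
        self.model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_s', pretrained=True)

        # Freeze the whole backbone
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the head (blocks[5]) with an empty Sequential, which passes its input through unchanged
        self.model.blocks[5] = nn.Sequential()
        # 1x1 convolutions projecting backbone features down to 256 channels
        self.conv1 = nn.Conv1d(in_channels=12288, out_channels=1024, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=1024, out_channels=256, kernel_size=1)
```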

Fused model

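```python
class FusedModel(nn.Module):
    def __init__(self, modelA, modelB):
        super().__init__()

        self.modelA = modelA
        self.modelB = modelB
        # 2-layer LSTM over per-frame features (256 from each branch, 512 after concatenation)
        self.lstm_extractor = nn.LSTM(input_size=512, hidden_size=512, num_layers=2, dropout=0.2, batch_first=True)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 8)

    def forward(self, x1, x2):
        x1 = self.modelA(x1)             # (batch, channels, time)
        x2 = self.modelB(x2)             # must be (batch, time, 256); ModelB.forward is not shown
        x1 = x1.permute(0, 2, 1)         # -> (batch, time, channels) for the batch_first LSTM
        x = torch.cat([x1, x2], dim=2)   # concatenate the two feature vectors at each time step
        out, (ht, ct) = self.lstm_extractor(x)
        out = ht[-1]                     # final hidden state of the last LSTM layer: (batch, 512)
        out = self.flat(out)             # no-op here, ht[-1] is already 2-D

        # There is no activation between fc1 and fc2, so together they act as one linear map
        out = self.fc1(out)
        out = self.fc2(out)
        return out


modelA = ModelA().float()
modelB = ModelB().float()

# `checkpoint` is loaded elsewhere (not shown)
modelA.load_state_dict(checkpoint['model_state_dict'])

# Note: blocks[5] was replaced by an empty nn.Sequential in ModelB, so it has no
# parameters and this loop does not unfreeze anything
for param in modelB.model.blocks[5].parameters():
    param.requires_grad = True
# conv1 and conv2 were created after the freeze loop, so they are already trainable
for param in modelB.conv1.parameters():
    param.requires_grad = True
for param in modelB.conv2.parameters():
    param.requires_grad = True

model = FusedModel(modelA, modelB)  # `model` is not defined in the original snippet
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # `lr` is defined elsewhere
```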
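
To see where the parameters are and which of them actually train, a small helper such as the hypothetical `count_parameters` below can be called on `model`, `modelA` or `modelB`. Two things stand out from arithmetic alone. First, `ModelB.conv1` (12288 → 1024, kernel size 1) has 12288 × 1024 + 1024 = 12 583 936 parameters on its own. Second, the toy stand-in for x3d_s shows that unfreezing a block that was replaced by an empty `nn.Sequential` changes nothing, because it owns no parameters. If the intent was to fine-tune the last stage of x3d_s, that stage is presumably `blocks[4]`. This is an assumption worth checking with `print(modelB.model.blocks)`.

```python
from torch import nn


def count_parameters(module):
    """Return (trainable, total) parameter counts of a module."""
    params = list(module.parameters())
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)
    return trainable, total


# ModelB.conv1 on its own: 12288 * 1024 weights + 1024 biases
print(count_parameters(nn.Conv1d(12288, 1024, kernel_size=1)))  # (12583936, 12583936)

# Toy stand-in for x3d_s: freeze everything, then replace the last block
backbone = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
for p in backbone.parameters():
    p.requires_grad = False
backbone[1] = nn.Sequential()  # mirrors self.model.blocks[5] = nn.Sequential()

# This loop iterates over nothing: an empty Sequential owns no parameters
for p in backbone[1].parameters():
    p.requires_grad = True

print(count_parameters(backbone))  # (0, 272)
```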

Is there anything I can do here to improve the speed, or are the model too complex and the dataset too big? Is there a faster model than ‘x3d_s’?
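
On the speed side, one standard option in PyTorch is automatic mixed precision. `torch.autocast` runs eligible operations, such as convolutions and matrix multiplications, in float16, and a `GradScaler` scales the loss so that small float16 gradients do not underflow. On GPUs with Tensor Cores (Colab GPUs such as the T4 have them), this often speeds up training and lowers activation memory, which may also leave room for a larger batch. The sketch below uses the hypothetical name `train_step_amp` and a toy single-input model; `FusedModel` would be called as `model(x1, x2)`. On a CPU-only machine it falls back to ordinary float32.

```python
import contextlib
import math

import torch
from torch import nn


def train_step_amp(model, inputs, targets, optimizer, criterion, scaler, device):
    """One training step with automatic mixed precision (active only on CUDA)."""
    model.train()
    optimizer.zero_grad()
    use_amp = device.type == "cuda"
    # Forward pass in float16 where safe on the GPU; plain float32 otherwise
    ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if use_amp else contextlib.nullcontext()
    with ctx:
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, skips the step if they contain inf/NaN
    scaler.update()                # adjusts the scale factor for the next step
    return loss.item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 8)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# A disabled scaler is a pass-through, so the same code runs on CPU
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
inputs = torch.randn(8, 10, device=device)
targets = torch.randint(0, 8, (8,), device=device)
loss = train_step_amp(model, inputs, targets, optimizer, nn.CrossEntropyLoss(), scaler, device)
print(math.isfinite(loss))
```

Newer PyTorch releases spell the scaler `torch.amp.GradScaler("cuda", ...)`. The older `torch.cuda.amp.GradScaler` still works there but emits a deprecation warning.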

Over the first three epochs, I am also concerned that the validation loss and validation accuracy are unstable. Is the batch size too small?
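
One thing worth ruling out for unstable validation numbers is how they are computed. Validation is normally run under `model.eval()`, which makes BatchNorm (used in ModelA) use its running statistics and disables the LSTM's inter-layer dropout, and inside `torch.no_grad()`. Below is a minimal sketch of such a loop. `evaluate` is a hypothetical name, and a toy model replaces `FusedModel`, which would be called as `model(x1, x2)`. Note also that with 700 videos in total the validation split is small. For a hypothetical 140-video split, a single video moves the accuracy by about 0.7 percentage points.

```python
import torch
from torch import nn


@torch.no_grad()
def evaluate(model, batches, criterion):
    """Return (mean loss per sample, accuracy) over an iterable of (inputs, targets)."""
    was_training = model.training
    model.eval()  # BatchNorm uses running statistics; dropout is disabled
    total_loss, correct, count = 0.0, 0, 0
    for inputs, targets in batches:
        logits = model(inputs)
        # criterion returns a batch mean, so weight it by batch size
        total_loss += criterion(logits, targets).item() * targets.size(0)
        correct += (logits.argmax(dim=1) == targets).sum().item()
        count += targets.size(0)
    model.train(was_training)  # restore the previous mode
    return total_loss / count, correct / count


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 8))
batches = [(torch.randn(4, 10), torch.randint(0, 8, (4,))) for _ in range(3)]
val_loss, val_acc = evaluate(model, batches, nn.CrossEntropyLoss())
print(val_loss, val_acc)
```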

Edit: after a few more epochs, I can see that my network doesn't really learn. Could the batch size also be the reason for that?

Thank you!