PyTorch model not converging in a CNN + LSTM classification task

I am training a PyTorch model to classify spectrograms of audio signals into two classes (normal, abnormal) using a CNN followed by an LSTM. The CNN extracts features along the time axis, and the LSTM classifies the resulting feature sequence. The tensor fed to the LSTM has shape [batch_size, 64, seq_length], which is permuted to [batch_size, seq_length, 64] to match batch_first=True.
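
For reference, a minimal sketch of that permute step on a dummy tensor (the sizes are illustrative, taken from the shape comments further down):

import torch

x = torch.randn(64, 64, 8)   # [batch_size, features, seq_length]
x = x.permute(0, 2, 1)       # [batch_size, seq_length, features]
print(x.shape)               # torch.Size([64, 8, 64]), what nn.LSTM(batch_first=True) expects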

The loss curve looks like this:

[loss curve image: the training loss stays flat at about 0.69 throughout training]

Here is my model definition:

import torch
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self, input_dim):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, 16, 3), nn.ReLU(), nn.MaxPool2d(2), nn.BatchNorm2d(16)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3), nn.ReLU(), nn.MaxPool2d(4), nn.BatchNorm2d(32)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2), nn.BatchNorm2d(64)
        )
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, (2, 1)))

        self.initialize_weights()

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # kaiming_uniform is deprecated; use the in-place kaiming_uniform_
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size=64, n_layers=1, device="cuda:0"):
        super(RNN, self).__init__()
        self.device = device
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)
        self.flatten = nn.Flatten()

    def forward(self, x):
        # x: [batch_size, features, 1, seq_length]
        out = x.squeeze(dim=2)  # [batch_size, features, seq_length]
        out = out.permute(0, 2, 1)  # [batch_size, seq_length, features]
        hidden_states = torch.zeros(self.n_layers, out.size(0), self.hidden_size).to(
            self.device
        )
        cell_states = torch.zeros(self.n_layers, out.size(0), self.hidden_size).to(
            self.device
        )
        out, _ = self.lstm(out, (hidden_states, cell_states))  # [batch_size, seq_length, hidden_size]
        # take the last timestep; it is already 2-D, so the flatten is a no-op
        out = self.flatten(out[:, -1, :])
        return out

    def initialize_weights(self):
        # note: nn.LSTM has no single .weight/.bias attribute; its parameters are
        # named weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}
        for m in self.modules():
            if isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if "weight" in name:
                        nn.init.kaiming_uniform_(param)
                    elif "bias" in name:
                        nn.init.constant_(param, 0)


class SCRNN(nn.Module):
    def __init__(
        self,
        input_dim,
        input_size,
        output_dim,
        device="cuda:0",
        n_layers_rnn=64,
    ):
        super(SCRNN, self).__init__()
        self.cnn = CNN(input_dim)
        self.rnn = RNN(input_size, 64, n_layers_rnn, device=device)
        self.fc1 = nn.Linear(64, 32)
        self.relu1 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(32, output_dim)

    def forward(self, x):
        cnn_out = self.cnn(x)  # [64, 64, 1, 8]
        rnn_out = self.rnn(cnn_out)  # [64, 64]
        out = self.fc1(rnn_out)  # [64, 32]
        out = self.relu1(out)
        out = self.dropout(out)
        out = self.fc2(out)  # [64, 2]
        return out
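
The dummy-input check I ran looks roughly like this (hypothetical sizes: a 1-channel 64x156 spectrogram happens to reduce to [batch, 64, 1, 8] after the four conv blocks, and n_layers_rnn=1 is only to keep the check fast; my real input and settings differ):

model = SCRNN(input_dim=1, input_size=64, output_dim=2, device="cpu", n_layers_rnn=1)
dummy = torch.randn(64, 1, 64, 156)  # [batch, channels, freq_bins, time_frames]
out = model(dummy)
print(out.shape)  # torch.Size([64, 2])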

Here is my training loop:

import copy

from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryF1Score,
    BinaryPrecision,
    BinarySpecificity,
)

# LEARNING_RATE, MAX_LR, STOP_EPOCHS, MODEL_NAME, MODEL_PATH, device,
# bcolors and update_metrics are defined globally elsewhere in my script
def training(model, train_dl, val_dl, N_EPOCHS):
    # Loss function, Optimizer and Scheduler
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=MAX_LR,
        steps_per_epoch=len(train_dl),
        epochs=N_EPOCHS,
        anneal_strategy="linear",
    )

    # metrics for visualization
    acc_fn = BinaryAccuracy().to(device)
    precision_fn = BinaryPrecision().to(device)
    speci_fn = BinarySpecificity().to(device)
    f1_fn = BinaryF1Score().to(device)
    
    best_val_acc = 0
    # state_dict() returns references to the live tensors, so deepcopy to snapshot them
    best_model_params = copy.deepcopy(model.state_dict())
    epochs_since_improvement = 0
    # Repeat for each epoch
    for epoch in range(N_EPOCHS):
        running_loss = 0.0
        correct_prediction = 0
        total_prediction = 0
        # Train step
        model.train()
        for i, data in enumerate(train_dl):
            # Get the input features and target labels and put them on the GPU
            inputs, labels = data[0].to(device), data[1].to(device)
            # Standardize using this batch's own mean and std
            inputs_m, inputs_s = inputs.mean(), inputs.std()
            inputs = (inputs - inputs_m) / inputs_s

            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            #keep stats for loss and accuracy
            running_loss += loss.item()

            # get the predicted class with the highest score
            _, prediction = torch.max(outputs, 1)
            # count of predictions that matched the target label
            correct_prediction += (prediction == labels).sum().item()
            total_prediction += prediction.shape[0]

        # Validation step
        val_running_loss = 0
        model.eval()
        with torch.inference_mode():
            for _, val_data in enumerate(val_dl):
                val_inputs, val_labels = val_data[0].to(device), val_data[1].to(device)
                # Standardize using the batch's own mean and std, as in training
                val_inputs_m, val_inputs_s = val_inputs.mean(), val_inputs.std()
                val_inputs = (val_inputs - val_inputs_m) / val_inputs_s

                val_outputs = model(val_inputs)
                _, val_prediction = torch.max(val_outputs, 1)
                
                val_loss = loss_fn(val_outputs, val_labels)
                val_running_loss += val_loss.item()
                acc_fn.update(val_prediction, val_labels)
                precision_fn.update(val_prediction, val_labels)
                speci_fn.update(val_prediction, val_labels)
                f1_fn.update(val_prediction, val_labels)
        
        # Print stats at the end of the epoch
        num_batches = len(train_dl)
        avg_loss = running_loss / num_batches
        acc = correct_prediction / total_prediction
        
        val_num_batches = len(val_dl)
        val_avg_loss = val_running_loss / val_num_batches
        val_acc = acc_fn.compute().item()
        precision = precision_fn.compute().item()
        speci = speci_fn.compute().item()
        f1 = f1_fn.compute().item()
        # reset the torchmetrics state so each epoch's metrics are independent
        acc_fn.reset()
        precision_fn.reset()
        speci_fn.reset()
        f1_fn.reset()

        if val_acc > best_val_acc:
            # new best validation accuracy: snapshot the parameters
            best_val_acc = val_acc
            best_model_params = copy.deepcopy(model.state_dict())
            epochs_since_improvement = 0
        elif epochs_since_improvement > STOP_EPOCHS:
            # stop early if validation accuracy doesn't improve for STOP_EPOCHS epochs
            print(f"{bcolors.WARNING}Training stopped as validation accuracy did not improve in the last {STOP_EPOCHS} epochs.{bcolors.ENDC}")
            break
        else:
            epochs_since_improvement += 1

        # print and log stats every epoch
        print(f'{bcolors.OKBLUE}Epoch: {epoch}, Train Loss: {avg_loss:.2f}, Validation Loss: {val_avg_loss:.2f}, Train Accuracy: {acc:.2f}, Validation Accuracy: {val_acc:.2f}{bcolors.ENDC}')
        # save metrics to a .csv file
        update_metrics(MODEL_NAME, val_acc, val_avg_loss, acc, avg_loss, precision, speci, f1)
        if epoch % max(1, N_EPOCHS // 10) == 0:
            # save model params and optimizer internal state every 10% of epochs
            MODEL_NAME_PTH = MODEL_NAME + '_' + str(val_acc) + '.pth'
            OPTIM_PTH = 'optim_' + MODEL_NAME + '_' + str(val_acc) + '.pth'
            MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME_PTH
            OPTIM_SAVE_PATH = MODEL_PATH / OPTIM_PTH
            print(f"Saving model to {MODEL_SAVE_PATH}")
            # was torch.save(obj=myModel.state_dict(), ...): myModel is a global,
            # the function should save the model passed in as an argument
            torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)
            torch.save(obj=optimizer.state_dict(), f=OPTIM_SAVE_PATH)
    
    print(f"{bcolors.OKGREEN}Finished Training{bcolors.ENDC}")
    return best_model_params
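
For context, I call the training function roughly like this (hypothetical driver code; the real constants, paths, and dataloaders are defined elsewhere in my script):

LEARNING_RATE = 1e-3  # illustrative values, not my exact settings
MAX_LR = 1e-2
STOP_EPOCHS = 10
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = SCRNN(input_dim=1, input_size=64, output_dim=2, device=device).to(device)
best_params = training(model, train_dl, val_dl, N_EPOCHS=100)
model.load_state_dict(best_params)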

I have tried several models before, and most of them converged: a standalone CNN worked, and so did a parallel CNN and LSTM whose outputs I concatenated. (A standalone LSTM did not work either.) However, the current sequential CNN + LSTM model does not converge. When I test it on a dummy input, the forward pass runs fine and produces the expected shapes. The only problem is that the training loss gets stuck at 0.69, which is about ln(2), i.e., the cross-entropy of a binary classifier that always predicts both classes with equal probability.
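
A quick check confirms that 0.69 is exactly the chance-level loss for two classes:

import math

# equal logits for both classes -> uniform predicted probabilities
logits = torch.zeros(8, 2)
labels = torch.randint(0, 2, (8,))
print(nn.CrossEntropyLoss()(logits, labels).item())  # 0.6931...
print(math.log(2))                                   # 0.6931...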

I also tried a TimeDistributed-style wrapper, similar to the one in this answer: https://stackoverflow.com/a/66955673/19544089, and the results were the same.
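
The idea there is the usual reshape-apply-reshape pattern; a minimal sketch (my paraphrase, not the exact code from the linked answer):

class TimeDistributed(nn.Module):
    """Apply a module independently to every timestep of a [batch, time, ...] tensor."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        batch, time = x.shape[0], x.shape[1]
        # merge batch and time, apply the module, then split them back apart
        out = self.module(x.reshape(batch * time, *x.shape[2:]))
        return out.reshape(batch, time, *out.shape[1:])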