EEG classification using PyTorch (LSTM + CNN combined feature extraction)

I am trying to classify time series EEG signals for imagined motor actions using PyTorch. The dataset I’m using is the eegmmidb dataset. I am using mne to get the events from the data. The shape of the input data is [batch_size, number of channels (electrodes), timesteps] (160 Hz sampling rate), which comes out to [batch_size, 64, 161] for a batch of events.

I want to construct a neural network which passes the data through both an LSTM and a CNN, extracting temporal features using the LSTM and spatial features using the CNN, and then uses the combined feature map to get a classified output.

If I am correct, the features are the 64 channels (or EEG electrodes) and the timestep.

I need help with the dimensioning and layering of the model, as I am not getting good accuracy on the testing set.

Here’s what I’ve done. This model reaches an accuracy of 63% on a dataset of around 30000 events, which is not up to the standard required -

modelTrain.py -

from sklearn.preprocessing import StandardScaler
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
from sklearn.metrics import accuracy_score

class NeuralNetwork(nn.Module):
    """CNN + LSTM classifier for 64-channel time series.

    Two ``Conv1d`` / ``MaxPool1d`` stages produce a feature map that is used
    twice: flattened and passed through two linear layers (the CNN branch),
    and fed step by step to a stacked LSTM (the RNN branch). Both feature
    vectors are concatenated, passed through an encoder/decoder pair and
    finally mapped to class logits.

    Args:
        num_classes: Number of logits produced by the last layer.
        seq_len: Number of timesteps in each input sample. It sizes the first
            fully connected layer, so it must describe the data actually fed
            to ``forward``.

    Raises:
        ValueError: If ``seq_len`` is too short to survive the two
            halving pooling layers.
    """

    def __init__(self, num_classes, seq_len=161):
        super().__init__()

        self.seq_len = seq_len

        # The convolutions keep the length (kernel 3, padding 1, stride 1);
        # each pooling layer halves it, rounding down.
        pooled_len = seq_len // 2 // 2
        if pooled_len < 1:
            raise ValueError(
                f"seq_len must be at least 4 to survive two pooling layers, got {seq_len}"
            )

        self.conv1 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # 128 channels * pooled_len steps; equals 5120 for the default seq_len.
        self.fc1 = nn.Linear(128 * pooled_len, 512)
        self.fc2 = nn.Linear(512, 120)
        self.drop = nn.Dropout(0.5)

        self.rnn = nn.LSTM(input_size=128, hidden_size=256, num_layers=5, batch_first=True)

        # 376 = 256 (LSTM hidden size) + 120 (fc2 output), the concatenated width.
        self.encoder = nn.Linear(376, 200)
        self.decoder = nn.Linear(200, 376)

        self.fc3 = nn.Linear(376, 100)
        self.fc4 = nn.Linear(100, num_classes)

    def forward(self, x):
        """Return class logits for a batch.

        Args:
            x: Tensor of shape ``(batch, 64, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        batch_size = x.size(0)

        # Shared convolutional trunk: conv -> ReLU -> dropout -> pool, twice.
        x_cnn = self.pool1(self.drop(torch.relu(self.conv1(x))))
        x_cnn = self.pool2(self.drop(torch.relu(self.conv2(x_cnn))))

        # CNN branch: flatten (channels, steps) and project to 120 features.
        # NOTE(review): fc1 and fc2 have no nonlinearity between them, so
        # together they act as one affine map; kept as-is to preserve behaviour.
        cnn_feature = x_cnn.view(batch_size, -1)
        cnn_feature = self.fc1(cnn_feature)
        cnn_feature = self.fc2(cnn_feature)

        # RNN branch: the LSTM is batch_first, so move steps ahead of channels
        # and keep only the output at the last step.
        x_rnn = x_cnn.permute(0, 2, 1)
        rnn_out, _ = self.rnn(x_rnn)
        rnn_feature = rnn_out[:, -1, :]

        combined_feature = torch.cat((rnn_feature, cnn_feature), dim=1)
        encoded = torch.relu(self.encoder(combined_feature))
        decoded = torch.relu(self.decoder(encoded))
        # NOTE(review): fc3 -> fc4 is likewise purely linear.
        output = self.fc3(decoded)
        output = self.fc4(output)
        return output
    
# --- Model, loss and optimiser -------------------------------------------
model = NeuralNetwork(3)
loss_fn = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=1e-3)

# --- Load the saved arrays ------------------------------------------------
# Labels are shifted down by one (presumably to make them zero-based for
# CrossEntropyLoss -- confirm against the values written by getData.py).
dataset_data = np.load('eeg_dataset_data.npy')
dataset_labels = np.load('eeg_dataset_labels.npy') - 1

# --- 80/20 split in stored order (no shuffling before the split) ----------
train_size = int(0.8 * len(dataset_data))
test_size = len(dataset_data) - train_size

train_data, test_data = dataset_data[:train_size], dataset_data[train_size:]
train_labels, test_labels = dataset_labels[:train_size], dataset_labels[train_size:]

# --- Standardise: fit on the training portion only, reuse for the test set.
# StandardScaler works on 2-D input, so each sample is flattened first and
# restored to its original shape afterwards.
scaler = StandardScaler()

train_flat = train_data.reshape(train_data.shape[0], -1)
train_data_scaled = scaler.fit_transform(train_flat).reshape(train_data.shape)

test_flat = test_data.reshape(test_data.shape[0], -1)
test_data_scaled = scaler.transform(test_flat).reshape(test_data.shape)

# --- Tensors, datasets and loaders ----------------------------------------
train_data_tensor = torch.tensor(train_data_scaled, dtype=torch.float32)
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_data_tensor = torch.tensor(test_data_scaled, dtype=torch.float32)
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)

train_dataset = TensorDataset(train_data_tensor, train_labels_tensor)
test_dataset = TensorDataset(test_data_tensor, test_labels_tensor)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Peek at a single batch to show the shapes the model will receive.
batch_data, batch_labels = next(iter(train_loader))
print("Batch data shape:", batch_data.shape)
print("Batch labels shape:", batch_labels.shape)

# --- Train, evaluating on the held-out set after every epoch --------------
epochs = 10

for epoch in range(epochs):
    # Training pass.
    model.train()
    running_loss = 0
    for batch_data, batch_labels in train_loader:
        optimiser.zero_grad()
        loss = loss_fn(model(batch_data), batch_labels)
        loss.backward()
        optimiser.step()
        running_loss += loss.item()

    # Evaluation pass: no gradients, dropout disabled by eval().
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for eval_inputs, eval_targets in test_loader:
            outputs = model(eval_inputs)
            preds = torch.max(outputs, 1)[1]  # index of the largest logit
            all_preds.extend(preds.numpy())
            all_labels.extend(eval_targets.numpy())

    # Report mean training loss per batch and test accuracy.
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss / len(train_loader):.4f} - Accuracy: {accuracy:.4f}")

getData.py

import os
import mne
import numpy as np
import torch

def preprocess_eeg_file(edf_file):
    """Turn one EDF recording into epoched data and per-epoch labels.

    The steps run in this order: read the file with MNE, derive events from
    its annotations, band-pass filter the signal between 1 and 40 Hz, cut
    epochs from -0.2 s to 0.8 s around each event with a ``(None, 0)``
    baseline, then drop bad epochs. Progress is printed after each step.

    Args:
        edf_file: Path of the ``.edf`` file to read.

    Returns:
        A ``(data, labels)`` tuple where ``data`` is ``epochs.get_data()`` and
        ``labels`` is the last column of ``epochs.events``.
    """
    print(f"Processing file: {edf_file}")
    raw = mne.io.read_raw_edf(edf_file, preload=True)
    print(f"Loaded raw data: {raw}")

    # Events are taken from the annotations stored in the recording.
    events, annotation_ids = mne.events_from_annotations(raw)
    print(f"Events found: {events.shape[0]}")

    # Band-pass filter, applied in place on the loaded recording.
    band_low_hz, band_high_hz = 1., 40.
    raw.load_data()
    raw.filter(band_low_hz, band_high_hz, fir_design='firwin')
    print("Data filtered between 1 and 40 Hz")

    # Window around each event, baseline-corrected on the pre-event part.
    epoch_options = dict(
        event_id=annotation_ids,
        tmin=-0.2,
        tmax=0.8,
        baseline=(None, 0),
        preload=True,
    )
    epochs = mne.Epochs(raw, events, **epoch_options)
    epochs.drop_bad()
    print(f"Epochs created and bad epochs dropped. Remaining epochs: {len(epochs)}")

    data, labels = epochs.get_data(), epochs.events[:, -1]
    print(f"Data shape: {data.shape}, Labels shape: {labels.shape}")
    return data, labels

def process_subfolder(subfolder_path):
    """Load and stack every ``.edf`` recording found directly in a folder.

    Files are visited in sorted name order, so the order of the stacked
    samples does not depend on how the filesystem happens to list entries.

    Args:
        subfolder_path: Directory to scan (not searched recursively).

    Returns:
        A ``(data, labels)`` tuple: the per-file arrays returned by
        ``preprocess_eeg_file``, each concatenated along axis 0.

    Raises:
        ValueError: If the folder contains no ``.edf`` files.
    """
    print(f"Processing subfolder: {subfolder_path}")

    # os.listdir gives entries in arbitrary order; sort for reproducibility.
    edf_names = sorted(
        name for name in os.listdir(subfolder_path) if name.endswith('.edf')
    )
    if not edf_names:
        # np.concatenate on an empty list fails with an unhelpful message.
        raise ValueError(f"No .edf files found in subfolder: {subfolder_path}")

    all_data = []
    all_labels = []
    for filename in edf_names:
        edf_file = os.path.join(subfolder_path, filename)
        data, labels = preprocess_eeg_file(edf_file)
        all_data.append(data)
        all_labels.append(labels)

    print(f"Finished processing subfolder: {subfolder_path}")
    return np.concatenate(all_data, axis=0), np.concatenate(all_labels, axis=0)

# Folder whose subfolders hold the .edf recordings. Built with os.path.join
# so the separator matches the host OS; on Windows this is still
# 'dataset\\files'.
main_folder = os.path.join('dataset', 'files')

# Per-subfolder arrays, concatenated into one dataset at the end.
dataset_data = []
dataset_labels = []

print(os.listdir(main_folder))
# Sorted so the order of samples in the saved arrays is reproducible.
for subfolder in sorted(os.listdir(main_folder)):
    print("In the data folder")
    subfolder_path = os.path.join(main_folder, subfolder)
    # Skip plain files that sit next to the subfolders.
    if os.path.isdir(subfolder_path):
        print(f'Starting to process subfolder: {subfolder}')
        data, labels = process_subfolder(subfolder_path)
        dataset_data.append(data)
        dataset_labels.append(labels)
        print(f"Processed subfolder: {subfolder}")

# Stack all subfolders along the sample axis.
dataset_data = np.concatenate(dataset_data, axis=0)
dataset_labels = np.concatenate(dataset_labels, axis=0)

print(f"Total dataset size: {dataset_data.shape}, Total labels size: {dataset_labels.shape}")

# Written to the current working directory.
np.save('eeg_dataset_data.npy', dataset_data)
np.save('eeg_dataset_labels.npy', dataset_labels)

print(f"Dataset saved. Total samples: {len(dataset_labels)}")