Training a classifier using "Online training"

I have thousands of time-series sequences, each between 500 and 20,000 samples long. I want to apply online training as described in https://en.wikipedia.org/wiki/Online_machine_learning .

The way I do it now:

I split each long sequence into many small windows and train/optimize on each window (pseudo-code below).

The problem:
Since I calculate the loss and optimize for every window in the long sequence, the model learns to predict only that sequence's class. Let's say the sequence belongs to class 1; after 1000 optimization steps with class 1 as the target, the model only ever predicts class 1.

When the model then starts training on sequence number two, it still only predicts class 1, even if sequence two belongs to class 3. After some number of windows it switches to only predicting class 3, and so on.

This continues no matter how many epochs I train for: the model only predicts whatever class it was trained on last. The root problem is that the model sees hundreds or thousands of examples in a row that all belong to the same class.

In pseudo-code it looks something like this:

#loop over all training sequences:
for step, (data, target) in enumerate(train_set_all):
    X = data    # shape: (batch, channels, seq_length)
    y = target

    #split the sequence into windows:
    windows = []
    seq_length = X.shape[2]
    for start_pos in range(0, seq_length - window_size + 1, stride_length):
        end_pos = start_pos + window_size
        window = np.copy(X[:, :, start_pos:end_pos])
        windows.append(torch.from_numpy(window))

    #train/optimize on every single window:
    for idx, window in enumerate(windows):
        output = model(window)
        loss = criterion(output, y)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
               


What I have tried:

I tried saving the losses and only optimizing after some number of windows (e.g. 100). However, accumulating the losses keeps every window's computation graph alive in memory and significantly decreases training speed, since calling loss.backward() on the combined loss of, say, 100 windows takes ages to compute.
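Roughly, that attempt looked like this (a simplified sketch; accumulate_every and accumulated are just my placeholder names):

accumulate_every = 100    # number of windows to accumulate before one update
accumulated = []

for idx, window in enumerate(windows):
    output = model(window)
    loss = criterion(output, y)
    accumulated.append(loss)    # each stored loss keeps its full graph alive

    if (idx + 1) % accumulate_every == 0:
        total_loss = torch.stack(accumulated).mean()
        optimizer.zero_grad()
        total_loss.backward()   # backprops through all 100 retained graphs at once
        optimizer.step()
        accumulated = []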

Are there any other options?

Regards

If anyone knows of examples or tutorials I can look into, that would be very appreciated as well, since I only get deep learning courses when I google "Pytorch Online Training" :slight_smile: which is not what I'm looking for.

Thanks