Why do IterableDataset and DataLoader cause issues with my training metrics?

I’m using a custom IterableDataset and DataLoader to train my model because the dataset is too large to fit in memory at once.

My data is split across multiple files, each containing many thousands of examples.
I have set up my training loop and found that I get exactly the same training loss and training accuracy in every epoch. The same happens for validation. Why does that happen?

I’m starting to think the issue is with how the dataset is shuffled. I know DataLoader can shuffle samples for you (with `shuffle=True`), but that option isn’t supported for an IterableDataset. To accommodate that, I created BatchShuffleDataset, which is another custom IterableDataset that shuffles the stream in chunks.

This is my DatasetClass:

class ProtobufIterableDataset(IterableDataset):
    """Stream processed examples from every protobuf cache file in a folder.

    Each file in ``data_folder`` is parsed as a
    ``cache_state_pb2.CacheStateProto``. Its ``proto_examples`` are passed
    through ``self.process`` and yielded one at a time, so only one parsed
    file is held in memory at once.

    Under a multi-worker DataLoader, examples are sharded round-robin at the
    example level: worker ``k`` of ``n`` yields the examples whose index
    within their file satisfies ``i % n == k``. Every worker still reads
    every file.

    Args:
        data_folder: Directory containing the serialized cache files.
        scaling: If True, load ``scalers.npz`` once into the class attribute
            ``scaler_data``, which is shared by all instances.
    """

    def __init__(self, data_folder, scaling=False):
        super().__init__()
        self.data_folder = data_folder
        self.scaling = scaling

        if scaling:
            scaler_path = "scalers.npz"
            # Load the data once and store it as a class attribute instead of
            # an instance attribute, to avoid the BadZipFile error. The `with`
            # closes the .npz file handle once the arrays have been copied
            # into a plain dict.
            if not hasattr(ProtobufIterableDataset, 'scaler_data'):
                with np.load(scaler_path) as npz:
                    ProtobufIterableDataset.scaler_data = dict(npz)

    def __iter__(self):
        """Yield ``self.process(example)`` for this worker's share of examples."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Main-process loading: a single "worker" owns every example.
            print("Single Worker Mode")
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        # Sorted so that every worker, and the single-worker path, walk the
        # files in the same deterministic order.
        for file in sorted(os.listdir(self.data_folder)):
            with open(os.path.join(self.data_folder, file), 'rb') as f:
                cache_data = cache_state_pb2.CacheStateProto()
                cache_data.ParseFromString(f.read())
            # The file is closed before yielding, so the handle is not held
            # open while the generator is suspended.

            # NOTE(review): self.process is not defined in the code shown;
            # it is assumed to be provided elsewhere on this class. Confirm.
            for i, example in enumerate(cache_data.proto_examples):
                if i % num_workers == worker_id:
                    yield self.process(example)

This is my Batch Shuffler Class:

# Batch Shuffler Class
class BatchShuffleDataset(IterableDataset):
    """Wrap an iterable dataset and yield its items in randomized order.

    Uses a streaming shuffle buffer holding up to ``batch_size`` items. The
    buffer is filled first. After that, each incoming item swaps places with
    a randomly chosen buffered item, and the displaced item is yielded. When
    the source is exhausted, the remaining buffer is shuffled and yielded.

    Unlike shuffling fixed, non-overlapping chunks, items can move across
    chunk boundaries, so the source order (e.g. file order) leaks through
    much less, with the same memory bound.

    Args:
        dataset: Any iterable of items (e.g. an IterableDataset).
        batch_size: Shuffle-buffer capacity. This is the number of items
            held in memory, not the training batch size. Must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(self, dataset, batch_size):
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        buffer = []
        for item in self.dataset:
            if len(buffer) < self.batch_size:
                buffer.append(item)
                continue
            # Buffer is full: evict a random slot, replace it with the new
            # item, and emit the evicted one.
            idx = random.randrange(self.batch_size)
            buffer[idx], item = item, buffer[idx]
            yield item
        # Drain whatever is left in random order.
        random.shuffle(buffer)
        yield from buffer

Here is my training file:

import torch
from torch.utils.data import DataLoader
from new_data_loader import ProtobufIterableDataset, BatchShuffleDataset
from model import SimpleModel
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

# Constants
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
NUM_WORKERS = 4

# Load Training Data
# The shuffle buffer randomizes the order in which training examples reach
# the model, so batch composition differs from epoch to epoch.
trainingDataset = ProtobufIterableDataset('training_examples', scaling=True)
trainingDataset = BatchShuffleDataset(trainingDataset, batch_size=SHUFFLE_BUFFER_SIZE)
trainingDataLoader = DataLoader(trainingDataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Load Validation Data
# No shuffle buffer here. Validation loss/accuracy are averaged over every
# sample, so sample order cannot change them. Skipping the buffer avoids
# holding up to SHUFFLE_BUFFER_SIZE extra examples per worker for no benefit.
validationDataset = ProtobufIterableDataset('validation_examples', scaling=True)
validationDataLoader = DataLoader(validationDataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): nn.CrossEntropyLoss below applies log-softmax itself, so
# SimpleModel must return raw logits (no softmax in forward). Confirm in model.py.
model = SimpleModel().to(device)

epochs = 10
learning_rate = 0.001
crossEntropyLoss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _combine_features(set_features, access_features, cache_features):
    """Flatten the cache features and concatenate all feature groups on dim 1."""
    # NOTE(review): the 17*9 reshape assumes each example carries exactly
    # 153 cache-feature values. Confirm against the dataset's process().
    cache_features_flat = cache_features.reshape(-1, 17 * 9)
    return torch.cat([set_features, access_features, cache_features_flat], dim=1)


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(avg_loss, accuracy_percent)``.

    If ``optimizer`` is given, the model is put in train mode with gradients
    enabled, and one optimizer step is taken per batch. Otherwise it runs in
    eval mode with gradients disabled.

    Both metrics are ``nan`` when the loader yields no samples. This means an
    empty loader can never crash or silently report another phase's numbers.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()

    total_loss = 0.0    # sum of per-sample losses for this pass
    total_correct = 0   # number of correct predictions for this pass
    total_samples = 0   # number of samples seen in this pass

    pbar = tqdm(loader, desc=desc, position=1, leave=False)
    with torch.set_grad_enabled(training):
        for set_features, access_features, cache_features, labels in pbar:
            combined_features = _combine_features(
                set_features, access_features, cache_features
            ).to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            # CrossEntropyLoss expects raw logits; the model must not apply
            # softmax itself.
            outputs = model(combined_features)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            # Labels are compared via argmax(dim=1), so they must be
            # (batch, num_classes) rows, presumably one-hot.
            predicted = outputs.argmax(dim=1)
            true_labels = labels.argmax(dim=1)
            total_correct += (predicted == true_labels).sum().item()
            total_samples += batch_size

            pbar.set_postfix({
                'loss': f'{total_loss / total_samples:.4f}',
                'acc': f'{100 * total_correct / total_samples:.2f}%',
                'samples': total_samples
            })

    if total_samples == 0:
        return float('nan'), float('nan')
    return total_loss / total_samples, 100 * total_correct / total_samples


def train(model, trainingDataLoader, validationDataLoader, optimizer, criterion, device, epochs=10):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Works with IterableDataset-backed loaders (no ``len`` is required).
    Per-epoch training and validation loss/accuracy are shown on the outer
    progress bar and printed. Metrics are computed independently for each
    phase and are ``nan`` if a loader yields nothing.

    Args:
        model: Module mapping the combined feature tensor to class logits.
        trainingDataLoader: Yields ``(set, access, cache, labels)`` batches.
        validationDataLoader: Same batch structure as the training loader.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        device: Device to move inputs and labels to.
        epochs: Number of epochs to run.
    """
    # Outer progress bar over epochs
    epoch_pbar = tqdm(range(epochs), desc='Training Progress', position=0)

    for epoch in epoch_pbar:
        final_train_loss, final_train_accuracy = _run_epoch(
            model, trainingDataLoader, criterion, device,
            desc=f'Training Epoch {epoch+1}/{epochs}',
            optimizer=optimizer,
        )
        final_val_loss, final_val_accuracy = _run_epoch(
            model, validationDataLoader, criterion, device,
            desc=f'Validation Epoch {epoch+1}/{epochs}',
        )

        # Update epoch progress bar with final metrics
        epoch_pbar.set_postfix({
            'train_loss': f'{final_train_loss:.4f}',
            'train_acc': f'{final_train_accuracy:.2f}%',
            'val_loss': f'{final_val_loss:.4f}',
            'val_acc': f'{final_val_accuracy:.2f}%'
        })

        # Print final statistics for the epoch
        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Training Loss: {final_train_loss:.4f}, '
              f'Training Accuracy: {final_train_accuracy:.2f}%')
        print(f'Validation Loss: {final_val_loss:.4f}, '
              f'Validation Accuracy: {final_val_accuracy:.2f}%\n')
        

        
# Run training only when executed as a script. DataLoader workers
# (num_workers > 0) started with the "spawn" method re-import this module;
# without the guard, each worker would start training again.
if __name__ == "__main__":
    train(model, trainingDataLoader, validationDataLoader, optimizer, crossEntropyLoss, device, epochs=epochs)

I don’t fully understand why you shuffle the samples within a batch, since the order of samples inside a batch does not matter. Is my understanding of your approach correct?

Yes, the order of samples within a batch shouldn’t matter. I shuffle because, as I understand it, shuffling ensures the model does not see the data in the same order every epoch, which helps generalization and reduces overfitting.

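I was applying softmax to the model outputs before passing them to CrossEntropyLoss, and that was causing this issue. CrossEntropyLoss expects raw (unnormalized) logits because it applies log-softmax internally. Applying softmax first squashes the outputs and gradients, so training stalls.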