I’m using a custom IterableDataset and DataLoader to train my model because the data is too large to hold in memory at once.
The data is spread across multiple files, each containing many thousands of examples.
I have set up my training loop and found that every epoch reports exactly the same training accuracy and training loss, and the same holds for validation. Why does that happen?
I’m starting to think the issue is the shuffling of the dataset. I know the DataLoader shuffles automatically, but that doesn’t apply to IterableDatasets, so to accommodate that I’ve written BatchShuffleDataset, a custom IterableDataset that shuffles within a buffer.
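For context, I can’t just pass shuffle=True to the DataLoader here; as far as I know, PyTorch rejects that option for iterable-style datasets. A minimal toy example (the TinyStream class below is just for illustration, not part of my code):

from torch.utils.data import DataLoader, IterableDataset

class TinyStream(IterableDataset):
    def __iter__(self):
        yield from range(10)

# Raises a ValueError along the lines of:
# "DataLoader with IterableDataset: expected unspecified shuffle option"
loader = DataLoader(TinyStream(), batch_size=4, shuffle=True)

That’s why I went with the buffer-based shuffler shown further down.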
This is my dataset class:
# new_data_loader.py
import os
import random

import numpy as np
import torch
from torch.utils.data import IterableDataset

import cache_state_pb2  # generated protobuf module


class ProtobufIterableDataset(IterableDataset):
    def __init__(self, data_folder, scaling=False):
        self.data_folder = data_folder
        self.scaling = scaling
        super().__init__()
        if scaling:
            scaler_path = "scalers.npz"
            # Load the scalers once and store them as a class attribute instead of an
            # instance attribute to avoid a BadZipFile error
            if not hasattr(ProtobufIterableDataset, 'scaler_data'):
                ProtobufIterableDataset.scaler_data = dict(np.load(scaler_path))

    def __iter__(self):
        # Shard at the example level: each worker takes every num_workers-th example from each file
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            # Walk the files in a deterministic order
            for file in sorted(os.listdir(self.data_folder)):
                with open(os.path.join(self.data_folder, file), 'rb') as f:
                    cache_data = cache_state_pb2.CacheStateProto()
                    cache_data.ParseFromString(f.read())
                # Each worker takes every nth example
                for i, example in enumerate(cache_data.proto_examples):
                    if i % num_workers == worker_id:
                        processed_example = self.process(example)
                        yield processed_example
        else:
            print("Single Worker Mode")
            for file in os.listdir(self.data_folder):
                with open(os.path.join(self.data_folder, file), 'rb') as f:
                    cache_data = cache_state_pb2.CacheStateProto()
                    cache_data.ParseFromString(f.read())
                for example in cache_data.proto_examples:
                    processed_example = self.process(example)
                    yield processed_example
This is my Batch Shuffler Class:
# Batch Shuffler Class
class BatchShuffleDataset(IterableDataset):
    def __init__(self, dataset, batch_size):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        # Collect items into a buffer of batch_size, shuffle the buffer, then yield from it
        batch = []
        dataset_iter = iter(self.dataset)
        for item in dataset_iter:
            batch.append(item)
            if len(batch) == self.batch_size:
                random.shuffle(batch)
                for shuffled_item in batch:
                    yield shuffled_item
                batch = []
        # Flush the final partial buffer
        if batch:
            random.shuffle(batch)
            for shuffled_item in batch:
                yield shuffled_item
Here is my training file:
import torch
from torch.utils.data import DataLoader
from new_data_loader import ProtobufIterableDataset, BatchShuffleDataset
from model import SimpleModel
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
# Constants
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
NUM_WORKERS = 4
# Load Training Data
trainingDataset = ProtobufIterableDataset('training_examples', scaling=True)
trainingDataset = BatchShuffleDataset(trainingDataset, batch_size=SHUFFLE_BUFFER_SIZE)
trainingDataLoader = DataLoader(trainingDataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# Load Validation Data
validationDataset = ProtobufIterableDataset('validation_examples', scaling=True)
validationDataset = BatchShuffleDataset(validationDataset, batch_size=SHUFFLE_BUFFER_SIZE)
validationDataLoader = DataLoader(validationDataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel().to(device)
epochs = 10
learning_rate = 0.001
crossEntropyLoss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
def train(model, trainingDataLoader, validationDataLoader, optimizer, criterion, device, epochs=10):
    """
    Training function with improved metric calculations, compatible with IterableDataset
    """
    # Create outer progress bar for epochs
    epoch_pbar = tqdm(range(epochs), desc='Training Progress', position=0)
    for epoch in epoch_pbar:
        # Training Phase
        model.train()
        running_train_loss = 0.0   # Running loss for the current epoch
        epoch_train_correct = 0    # Total correct predictions for the epoch
        epoch_train_samples = 0    # Total samples processed in the epoch
        # Create progress bar for training batches
        train_pbar = tqdm(trainingDataLoader,
                          desc=f'Training Epoch {epoch+1}/{epochs}',
                          position=1,
                          leave=False)
        # For every batch
        for batch_idx, (set_features, access_features, cache_features, labels) in enumerate(train_pbar, 1):
            # Combine features
            cache_features_flat = cache_features.reshape(-1, 17*9)
            combined_features = torch.cat([set_features, access_features, cache_features_flat], dim=1)
            # Move to device
            combined_features = combined_features.to(device)
            labels = labels.to(device)
            # Zero gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(combined_features)
            loss = criterion(outputs, labels)
            # Backward pass
            loss.backward()
            optimizer.step()
            # Update running loss
            batch_size = labels.size(0)
            running_train_loss += loss.item() * batch_size
            # Calculate accuracy
            predicted = outputs.argmax(dim=1)
            true_labels = labels.argmax(dim=1)
            batch_correct = (predicted == true_labels).sum().item()
            # Update totals
            epoch_train_samples += batch_size
            epoch_train_correct += batch_correct
            # Calculate current accuracy
            current_accuracy = 100 * epoch_train_correct / epoch_train_samples
            current_loss = running_train_loss / epoch_train_samples
            # Update training progress bar with running averages
            train_pbar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'acc': f'{current_accuracy:.2f}%',
                'samples': epoch_train_samples
            })
        # Store final training metrics
        final_train_loss = current_loss
        final_train_accuracy = current_accuracy
        # Validation Phase
        model.eval()
        running_val_loss = 0.0
        epoch_val_correct = 0
        epoch_val_samples = 0
        # Create progress bar for validation batches
        val_pbar = tqdm(validationDataLoader,
                        desc=f'Validation Epoch {epoch+1}/{epochs}',
                        position=1,
                        leave=False)
        with torch.no_grad():
            for batch_idx, (set_features, access_features, cache_features, labels) in enumerate(val_pbar, 1):
                # Combine features
                cache_features_flat = cache_features.reshape(-1, 17*9)
                combined_features = torch.cat([set_features, access_features, cache_features_flat], dim=1)
                # Move to device
                combined_features = combined_features.to(device)
                labels = labels.to(device)
                # Forward pass
                outputs = model(combined_features)
                loss = criterion(outputs, labels)
                # Update running loss
                batch_size = labels.size(0)
                running_val_loss += loss.item() * batch_size
                # Calculate accuracy
                predicted = outputs.argmax(dim=1)
                true_labels = labels.argmax(dim=1)
                batch_correct = (predicted == true_labels).sum().item()
                # Update totals
                epoch_val_samples += batch_size
                epoch_val_correct += batch_correct
                # Calculate current accuracy
                current_accuracy = 100 * epoch_val_correct / epoch_val_samples
                current_loss = running_val_loss / epoch_val_samples
                # Update validation progress bar
                val_pbar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{current_accuracy:.2f}%',
                    'samples': epoch_val_samples
                })
        # Store final validation metrics
        final_val_loss = current_loss
        final_val_accuracy = current_accuracy
        # Update epoch progress bar with final metrics
        epoch_pbar.set_postfix({
            'train_loss': f'{final_train_loss:.4f}',
            'train_acc': f'{final_train_accuracy:.2f}%',
            'val_loss': f'{final_val_loss:.4f}',
            'val_acc': f'{final_val_accuracy:.2f}%'
        })
        # Print final statistics for the epoch
        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Training Loss: {final_train_loss:.4f}, '
              f'Training Accuracy: {final_train_accuracy:.2f}%')
        print(f'Validation Loss: {final_val_loss:.4f}, '
              f'Validation Accuracy: {final_val_accuracy:.2f}%\n')


# Run training
train(model, trainingDataLoader, validationDataLoader, optimizer, crossEntropyLoss, device, epochs=epochs)