"Cleaning" dataset before DataLoader creation can significantly improve performance

Hi,

I recently discovered by chance that if I first extract the data and target tensors from a dataset and then create a DataLoader from a custom TensorDataset, I get a huge performance improvement (up to 2-3x faster training).

Here is some code:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

BATCH_SIZE = 64  # example value; use whatever batch size you train with

def clean_dataset(dataset: torch.utils.data.Dataset,
                  max_divide: bool = True,
                  ) -> torch.utils.data.TensorDataset:
    """
    Extracts the data and target tensors from the dataset and normalizes the data.
    It can also significantly improve training performance.
    """
    # For MNIST, .data is an (N, 28, 28) uint8 tensor; add a channel dimension.
    X, y = dataset.data.unsqueeze(1).float(), dataset.targets.long()
    if max_divide:
        X /= X.max()  # X.max() is 255 for MNIST, matching ToTensor's scaling to [0, 1]
    return torch.utils.data.TensorDataset(X, y)

# Download mnist training data
training_data = datasets.MNIST(
    root="../datasets",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download mnist test data
test_data = datasets.MNIST(
    root="../datasets",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Clean the datasets (if I uncomment the next two lines, I get the huge performance improvement)

# training_data = clean_dataset(training_data)
# test_data = clean_dataset(test_data)

# Create DataLoaders
train_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)

# ... and then perform a standard classification experiment using a simple feed-forward network
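
For reference, here is a minimal sketch of how one could time the two pipelines in isolation (time_one_epoch is an illustrative helper, not part of my training code, and it assumes the two clean_dataset lines above are left commented out so training_data is still the raw MNIST dataset):

import time

def time_one_epoch(loader):
    # Iterate one full epoch without any model work, so the measured
    # time is (mostly) pure data-loading overhead.
    start = time.perf_counter()
    for X_batch, y_batch in loader:
        pass
    return time.perf_counter() - start

raw_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
cleaned_loader = DataLoader(clean_dataset(training_data), batch_size=BATCH_SIZE, shuffle=True)
print(f"raw:     {time_one_epoch(raw_loader):.2f}s")
print(f"cleaned: {time_one_epoch(cleaned_loader):.2f}s")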

Given the above, I have 2 questions:

  1. Why is training so much faster when I create the DataLoaders using the above trick?
  2. Are there any unexpected side effects of this approach? Should I take any special precautions? Classification accuracy does not seem to change, only training speed.

Thank you in advance for your help!

Michalis

By default, MNIST loads each sample as a PIL.Image and then applies all transformations to it before returning a tensor (assuming ToTensor is used). Your trick skips this per-sample image-to-tensor conversion, which is why you are seeing the performance improvement.
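
In other words, every __getitem__ call on the stock dataset does roughly the following (simplified from torchvision's MNIST implementation, for illustration only), while the TensorDataset version is just a plain tensor index:

from PIL import Image

# Stock MNIST: runs on every sample, every epoch
img = Image.fromarray(training_data.data[0].numpy(), mode="L")  # uint8 tensor -> PIL image
x = ToTensor()(img)                                             # PIL image -> float tensor in [0, 1]

# TensorDataset: the conversion was already done once, up front
x, y = clean_dataset(training_data)[0]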