Split a dataset in PyTorch for CIFAR10 (or any other dataset)

How can I split a dataset into 10 equally sized parts in PyTorch?
The goal is to train on each subset of samples individually and aggregate their gradients to update the model for the next iteration.

cc @VitalyFedyunin for data loader


You can use the DistributedSampler for this.
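
For example, a rough sketch of what that could look like outside of an actual torch.distributed launch (the dataset, batch size, and number of shards below are just placeholders; num_replicas and rank are passed explicitly here instead of being read from the process group):

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                       transform=transforms.ToTensor())

n_shards = 10
loaders = []
for rank in range(n_shards):
    # Each sampler yields a disjoint 1/n_shards slice of the dataset for its rank.
    sampler = DistributedSampler(dataset, num_replicas=n_shards, rank=rank, shuffle=True)
    loaders.append(DataLoader(dataset, batch_size=64, sampler=sampler))

# loaders[i] now iterates only over shard i of the data.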


How can we split the 60,000 MNIST samples into 10 parts, so that the first part (data1) contains 6,000 samples, the second part (data2) contains the next 6,000, and so on? How would DistributedSampler be used for this?

Below is the code in Python using Keras/TensorFlow.

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.utils import np_utils

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# worker_X[i] contains the i-th part of the training data's features.
# worker_y[i] contains the i-th part of the training data's labels.
worker_X = []
worker_y = []

# The dataset size is 60,000. We want to split it among 10 workers.
Batch = 60000 // 10

for i in range(10):
    worker_X.append(X_train[i*Batch: Batch + i*Batch])
    worker_y.append(y_train[i*Batch: Batch + i*Batch])

I am looking for the same thing in PyTorch.

@ohm The AWS tutorial goes over how to use the DistributedSampler to split your dataset into even parts: https://pytorch.org/tutorials/beginner/aws_distributed_training_tutorial.html#initialize-dataloaders. This is not an official PyTorch tutorial, but the “With multiprocessing” section in https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html describes how to use DistributedSampler to do this for MNIST.


Thanks.
"The AWS tutorial goes over how to use the DistributedSampler to split your dataset into even parts: pytorch.org/tutorials/beginner/aws_distributed_training_tutorial.html#initialize-dataloaders"
I did not see such a thing in that example. The 'num_workers=workers' argument there is different from what I am looking for. As I mentioned, I want to split the data among workers, with a fixed number of samples in each worker, and I need to access a specific worker's data when I call that worker.
I had already seen https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html, and again it is not clear how it splits the data among workers, or whether it even does so.

A simple example that splits a dataset among n_workers:

import torch
import torchvision
import torchvision.transforms as transforms

# The number of workers:
n_workers = 4

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transforms.ToTensor())

# Split the test set; you can access the data held by worker 1 with Worker_data[1], and so on.
temp = tuple(len(testset) // n_workers for i in range(n_workers))
Worker_data = torch.utils.data.random_split(testset, temp)
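
To then train on each split, a small usage sketch (the batch size here is arbitrary): wrap each subset in its own DataLoader and iterate over one worker's shard at a time.

worker_loaders = [torch.utils.data.DataLoader(subset, batch_size=64, shuffle=True)
                  for subset in Worker_data]

# e.g. iterate over the data held by worker 1
for images, labels in worker_loaders[1]:
    pass  # forward/backward pass on this worker's shard goes here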

@Ohm My apologies for the previous reply not being clear. The splitting happens automatically inside the DistributedSampler based on the rank that you provide; the relevant piece of code in the DistributedSampler illustrates how this splitting is done per worker based on the rank.
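
Roughly, the index-selection step looks like the following (a paraphrase for illustration, not the verbatim source; the dataset size, number of replicas, and rank below are just example values):

import math

dataset_len = 60000      # e.g. MNIST
num_replicas = 10        # number of workers
rank = 3                 # this worker's rank

num_samples = math.ceil(dataset_len / num_replicas)
total_size = num_samples * num_replicas

indices = list(range(dataset_len))                 # the sampler shuffles these with a seeded generator
indices += indices[:total_size - len(indices)]     # pad so every rank gets the same number of samples
indices = indices[rank:total_size:num_replicas]    # each rank keeps a disjoint strided slice

assert len(indices) == num_samples                 # 6,000 indices for this worker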
