Sampling class-balanced data

I am using:

# Shared preprocessing for both FashionMNIST splits.
# NOTE(review): the original transform list was truncated in this snippet;
# ToTensor() is an assumed minimum — confirm against the real pipeline.
FashionMNIST_transform = transforms.Compose([
    transforms.ToTensor(),
])

# 60k-image training split (downloaded on first use).
FashionMNIST_train_dataset = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=FashionMNIST_transform,
)

# Official test split; passed as the second argument of the DivideData(...)
# call, so it must be defined alongside the training split.
FashionMNIST_test_dataset = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=False,
    download=True,
    transform=FashionMNIST_transform,
)

I want to split this data into N class-balanced labeled samples and keep the rest (60000 − N) separately.
How can I do that?

1 Like

You could use e.g. sklearn.model_selection.StratifiedShuffleSplit (or any other suitable splitting function) to create the split indices.
Once you have the indices, wrap your dataset together with them in a `Subset` and pass that to a `DataLoader`.

1 Like

This is what I did:

def GetIndices(data, N, rng=None):
    """Split a labeled dataset's indices into a class-balanced pick and the rest.

    Each class contributes ``N // len(data.classes)`` randomly chosen indices,
    so when ``N`` is not a multiple of the class count the pick holds fewer
    than ``N`` indices.

    Args:
        data: Dataset with a ``classes`` attribute whose length is the number
            of classes. Labels are read from ``data.targets`` when present and
            no ``target_transform`` is set (this avoids running the image
            transform on every sample); otherwise the dataset is iterated as
            ``(sample, label)`` pairs. Labels must be integers in
            ``range(len(data.classes))``.
        N: Total number of indices to pick across all classes.
        rng: Optional object with a ``sample`` method (e.g.
            ``random.Random(seed)``) for reproducible splits; defaults to the
            ``random`` module.

    Returns:
        ``(picked_indices, not_picked_indices)``: disjoint lists that together
        cover every index of ``data``, grouped by class and ascending within
        each class.

    Raises:
        ValueError: if some class has fewer than ``N // len(data.classes)``
            samples (raised by ``sample``).
    """
    import random

    sampler = random if rng is None else rng
    n_classes = len(data.classes)
    per_class = N // n_classes

    targets = getattr(data, "targets", None)
    if targets is None or getattr(data, "target_transform", None) is not None:
        # No raw label array (or it would not match __getitem__'s labels):
        # fall back to reading labels through the dataset itself.
        labels = [label for _, label in data]
    elif hasattr(targets, "tolist"):
        labels = targets.tolist()  # e.g. a tensor of labels
    else:
        labels = list(targets)

    # buckets[c] holds the dataset indices whose label is c, in ascending order.
    buckets = [[] for _ in range(n_classes)]
    for idx, label in enumerate(labels):
        buckets[int(label)].append(idx)

    picked_indices = []
    not_picked_indices = []
    for bucket in buckets:
        # A set makes each membership test O(1) instead of a list scan.
        chosen = set(sampler.sample(bucket, per_class))
        for idx in bucket:
            if idx in chosen:
                picked_indices.append(idx)
            else:
                not_picked_indices.append(idx)
    return picked_indices, not_picked_indices

def DivideData(train_dataset, test_dataset, N, batch_size=1):
    """Build loaders for a class-balanced labeled subset and the held-out data.

    Args:
        train_dataset: Labeled training dataset passed to ``GetIndices``.
        test_dataset: Separate test dataset, loaded as-is.
        N: Total number of balanced samples to draw from ``train_dataset``.
        batch_size: Batch size for all three loaders. The default of 1 matches
            ``DataLoader``'s own default. NOTE(review): the original loader
            arguments were truncated in this snippet — confirm.

    Returns:
        ``(trainloader, testloader1, testloader2)``:
        - ``trainloader`` iterates the picked balanced subset;
        - ``testloader1`` iterates the remaining training samples;
        - ``testloader2`` iterates ``test_dataset``, shuffled.
        To iterate the last two as one, wrap ``testloader1.dataset`` and
        ``testloader2.dataset`` in a ``ConcatDataset``.
    """
    from torch.utils.data import DataLoader, Subset

    picked_indices, not_picked_indices = GetIndices(train_dataset, N)

    trainloader = DataLoader(Subset(train_dataset, picked_indices), batch_size=batch_size)
    testloader1 = DataLoader(Subset(train_dataset, not_picked_indices), batch_size=batch_size)
    testloader2 = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return trainloader, testloader1, testloader2

# Draw 100 labeled samples in total (split evenly across the classes);
# everything else from the training split goes to testloader1.
trainloader, testloader1, testloader2 = DivideData(
    FashionMNIST_train_dataset,
    FashionMNIST_test_dataset,
    N=100,
)

Any idea how to concatenate testloader1 with testloader2?

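You could wrap the `Subset` and `test_dataset` in a `ConcatDataset` and pass the concatenated dataset to a `DataLoader`, e.g. `DataLoader(ConcatDataset([testloader1.dataset, testloader2.dataset]), shuffle=True)`.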