Get a batch that has images of the same class

I would like to generate batches in such a way that all the datapoints/images in the batch belong to the same class. For example if batch size is set 32 and considering MNIST dataset, the batch should have images of the same class, say all ‘0’ or all ‘1’.
The reason I want to do this is as part of research

I have written the following code to get the batch in the required way. however training is not taking place(the training error was actually increasing with epochs):

# Generate dataloaders for each class

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(mnist_train, batch_size = 256, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 256, shuffle=False)

A = {}
B = {}
C = {}
# B and C are dictionaries that contain DataLoaders for each class

batch_size = 32

for i in range(0, 10):
    idx = mnist_train.train_labels==i
    A['mnist_train_{}'.format(i)] = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
    A['mnist_train_{}'.format(i)].data =[idx]
    A['mnist_train_{}'.format(i)].targets = mnist_train.targets[idx]

    idx = mnist_test.test_labels==i
    A['mnist_test_{}'.format(i)] = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())
    A['mnist_test_{}'.format(i)].data =[idx]
    A['mnist_test_{}'.format(i)].targets = mnist_test.targets[idx]

    B['train_loader_{}'.format(i)] = DataLoader(A['mnist_train_{}'.format(i)], batch_size = batch_size)
    C['test_loader_{}'.format(i)] = DataLoader(A['mnist_test_{}'.format(i)], batch_size = batch_size)

I use the above dataloaders in the following training function:

def epoch_all(loaders, model, opt=None):
  total_loss, total_err = 0., 0.
  for loader in loaders.values():
    for X,y in loader:
      X,y =,
      yp = model(X)
      loss = nn.CrossEntropyLoss()(yp,y) 
      if opt:
      total_err += (yp.max(dim=1)[1] != y).sum().item()
      total_loss += loss.item() * X.shape[0]

  if opt:
    return total_err / 60000, total_loss / 60000
    return total_err / 10000, total_loss / 10000

The above function is run using:

    for t in range(10):
        train_err, train_loss = epoch_all(B, model_cnn, opt)
        test_err, test_loss = epoch_all(C, model_cnn)

Would be grateful for any help