Why does the duration of an epoch decrease for larger batch sizes?

I’m trying to learn torch (and ML in general) by training a fairly simple model on MNIST. During training I observe that the duration of an epoch decreases for larger batch sizes. This doesn’t make any sense to me. Doesn’t a larger batch size mean that my model processes more data per epoch? Is it due to a bug in my program, or is this expected for training in general?

For example, when the batch size is 60k (the entire dataset), an epoch takes only about 1.5 seconds. When the batch size is 64, it takes about 8 seconds.

My code is below.

1 - Main training loop

    # fetch data
    data_dir = '../data/mnist/'
    apply_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=apply_transform)
    test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=apply_transform)
    train_loader, test_loader = utils.load_data_to_GPU(train_dataset, test_dataset, args)

    # initialize the model
    model = torch.nn.DataParallel(CNNMnist()).to(args.device)
    criterion = torch.nn.NLLLoss().to(args.device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 
    # training
    test_accs, per_class_accs = [], []
    for epoch in tqdm(range(args.global_ep)):
        print(f'| Global Epoch : {epoch+1} |')
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            # backpropagate and update the weights
            batch_loss.backward()
            optimizer.step()
        test_acc, per_class_acc = infer.get_test_accuracy(model, test_loader, num_classes=len(test_dataset.classes))
        print(f'| Test Acc : {test_acc} ')
        print(f'| Per-class Acc : {per_class_acc} |\n' )   

2 - Data loading functionalities

def load_data_to_GPU(train_dataset, test_dataset, args):
    # put all training data on the GPU by loading it as one big batch
    train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)
    images, labels = next(iter(train_loader))
    images, labels = images.to(args.device), labels.to(args.device)
    train_loader = DataLoader(torch.utils.data.TensorDataset(images, labels), batch_size=args.local_bs, shuffle=True)
    # do the same for the test data
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
    images, labels = next(iter(test_loader))
    images, labels = images.to(args.device), labels.to(args.device)
    test_loader = DataLoader(torch.utils.data.TensorDataset(images, labels), batch_size=len(test_dataset), shuffle=False)
    return train_loader, test_loader

3 - Inference function

def get_test_accuracy(model, test_loader, num_classes=10):
    """Returns the overall accuracy and per-class accuracy."""
    confusion_matrix = torch.zeros(num_classes, num_classes)
    total_samples, correctly_labeled_samples = 0.0, 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            # forward pass to get predictions for the current batch
            outputs = model(images)
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            # count correctly predicted images in the current batch
            correctly_labeled_samples += torch.sum(torch.eq(pred_labels, labels)).item()
            total_samples += len(labels)
            # fill confusion_matrix
            for t, p in zip(labels.view(-1), pred_labels.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

    accuracy = (correctly_labeled_samples/total_samples)*100
    per_class_accuracy = confusion_matrix.diag()/confusion_matrix.sum(1)
    return accuracy, per_class_accuracy

4 - My model

class CNNMnist(nn.Module):
    def __init__(self):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

I’m guessing that, with 1.5 seconds for 60k samples, you are using a GPU. This result comes from the parallel processing power of GPUs, quite simply. They are good at processing data in parallel, and as long as there is enough “room” (memory, parallel threads, cores), it is faster to give them more data per batch, because an epoch then needs fewer batches to cover the same dataset.

“Real life” comparison: if a sawmill has the capacity to saw 20 logs at the same time but is only fed one at a time, it will take longer than if it is fed 20 at a time.
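
To put rough numbers on this, here is a small sketch that uses only the figures from the question (60,000 training images, about 8 s per epoch at batch size 64, about 1.5 s at batch size 60,000). The helper name `steps_per_epoch` is made up for illustration, and it assumes the loader keeps the last partial batch, which is the `DataLoader` default (`drop_last=False`).

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to see every sample once (last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)


NUM_TRAIN = 60_000  # size of the MNIST training set

# (batch size, measured epoch time in seconds) as reported in the question
measurements = [(64, 8.0), (60_000, 1.5)]

for batch_size, epoch_seconds in measurements:
    steps = steps_per_epoch(NUM_TRAIN, batch_size)
    per_step_ms = 1000 * epoch_seconds / steps
    print(f"batch_size={batch_size}: {steps} steps/epoch, ~{per_step_ms:.1f} ms per step")
```

Either way an epoch covers the same 60,000 images. What changes is the number of steps: 938 small steps versus 1 big one. The single big step is much slower than one small step, but far cheaper per image, which is where the shorter epoch comes from.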


Both measurements are on the GPU, with the data pre-loaded to the GPU before the training loop begins.

Still, I don’t get how the duration of an epoch can be shorter for bigger batch sizes. Would you mind elaborating a bit more?

It’s simple.
Say your GPU has 4 cores, each core can process 64 images per second, and all 4 cores work in parallel.

So your GPU will take the same amount of time to process 64 images as 256 images. In this simplified picture, each batch takes about 1 second no matter how many images it holds, up to 256.

So to process 256 images, it will take 1 second with a batch size of 256, 4 seconds with a batch size of 64, and 8 seconds with a batch size of 32.

Try a batch size of 1 and it will take even longer.
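
The toy model above can be written down in a few lines of Python. This is only a sketch of that simplified picture, not a model of a real GPU: it assumes every batch of up to 256 images (4 cores × 64 images) costs one fixed 1-second pass, and the function name `toy_epoch_seconds` and its parameters are made up for illustration.

```python
import math


def toy_epoch_seconds(num_images: int, batch_size: int,
                      capacity: int = 256, seconds_per_pass: float = 1.0) -> float:
    """Epoch time in the toy model: each pass handles up to `capacity` images
    and always costs `seconds_per_pass`, however full it is."""
    full_batches, remainder = divmod(num_images, batch_size)
    # A batch larger than the capacity needs several passes.
    passes = full_batches * math.ceil(batch_size / capacity)
    if remainder:
        # The last, smaller batch still costs at least one pass.
        passes += math.ceil(remainder / capacity)
    return passes * seconds_per_pass


for bs in (256, 64, 32, 1):
    print(f"batch_size={bs}: {toy_epoch_seconds(256, bs)} s")
```

In this model the time per epoch is driven by the number of passes, not by the number of images, so small batches leave most of the hardware idle on every pass.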

Ah gotcha. Thank you.
