How does PyTorch handle incomplete batches?

I know that when training, if you are using BatchNorm and have only 1 sample in your batch, then it can error out, which makes sense. I also know that you can use drop_last=True in your dataloader for those situations. But BatchNorm aside, how does PyTorch handle incomplete batches in training and inference? Is it able to deal with the different batch sizes or does it somehow pad the batch?

And the second part, similar question, does it matter at all if the size of your dataset is divisible by batch size when doing inference (validation/test)?

Hi @bfeeny, since the last mini-batch is simply a different size than the rest, a data loader that yields variable-sized batches handles it fine. A PyTorch model just expects an input of shape (batch_size, channels, height, width), and the batch size can vary from batch to batch.
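
A minimal sketch (the dataset and batch sizes are made up for illustration) showing that the DataLoader simply yields a smaller final batch when the dataset length is not divisible by the batch size:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake samples of shape (3, 8, 8); batch_size=4 leaves a final batch of 2
dataset = TensorDataset(torch.randn(10, 3, 8, 8), torch.randint(0, 2, (10,)))
loader = DataLoader(dataset, batch_size=4)

for x, y in loader:
    print(x.shape)  # torch.Size([4, 3, 8, 8]) twice, then torch.Size([2, 3, 8, 8])
```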

There is no padding involved, and none is needed, since you never specify the batch size as a model attribute. As you’ve already explained, a batch size of 1 (potentially the last batch returned by the DataLoader) could create an error if specific layers (such as batchnorm) cannot calculate their statistics from a single sample. The workaround would be to drop this batch, as you’ve mentioned.
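
For completeness, a small sketch of that workaround (sizes are just examples): drop_last=True makes the DataLoader skip the incomplete final batch, so a single-sample batch never reaches the batchnorm layers during training:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 samples with batch_size=3 would leave a final batch of 1 sample
dataset = TensorDataset(torch.randn(10, 3, 8, 8), torch.randint(0, 2, (10,)))
loader = DataLoader(dataset, batch_size=3, drop_last=True)

for x, y in loader:
    print(x.shape)  # three batches of torch.Size([3, 3, 8, 8]); the leftover sample is dropped
```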

As @Usama_Hasan said, the input tensor should contain the batch dimension (usually in dim0) and your model will be able to deal with arbitrary batch sizes (assuming your system doesn’t run out of memory).
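
A quick sketch (the tiny model is made up for illustration) of the same model handling different batch sizes, since only dim0 changes:

```python
import torch
import torch.nn as nn

# toy model; no batch size is baked into any layer
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

for batch_size in (4, 7, 1):
    out = model(torch.randn(batch_size, 3, 8, 8))
    print(out.shape)  # torch.Size([4, 2]), torch.Size([7, 2]), torch.Size([1, 2])
```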

Usually you would call model.eval() while using the validation dataset or running inference, which switches the behavior of some modules. E.g., batchnorm layers will use their running stats, so they don’t need to calculate statistics from the input batch and won’t raise an error for a single-sample batch.
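
A small sketch of that eval-mode behavior (the toy model is made up): after model.eval() the batchnorm layer uses its running stats, so even a single-sample batch works during validation/inference:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 8), nn.BatchNorm1d(8), nn.ReLU())

model.train()
# model(torch.randn(1, 16))  # would raise an error: batchnorm cannot compute
#                            # batch statistics from a single value per channel

model.eval()  # batchnorm now uses its running stats
with torch.no_grad():
    out = model(torch.randn(1, 16))  # a single sample is fine in eval mode
print(out.shape)  # torch.Size([1, 8])
```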
