Exploding memory when using variable-size batches

Hi all,

I am currently working on a computer vision project where, during inference, the number of patches to be classified varies quite heavily (in the 20 to 400 range). When I run these variable-sized batches through the network, memory usage absolutely explodes: checked with htop, it is almost twice as high as when I use a fixed-size input batch of 400 (the maximum).
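For reference, the resident memory can also be checked from inside the script with a small psutil helper along these lines (just an illustrative sketch; the numbers above are from htop):

import os
import psutil

def rss_mb():
    # Resident set size (physical memory) of the current process, in MB
    return psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2

Calling e.g. print(f"{rss_mb():.0f} MB") once per iteration of the loop in the repro below gives the per-step picture.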

Is there a logical reason why this happens? I could split the batches into equally sized chunks, but that would hurt performance since more forward passes are required.
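The chunked workaround I have in mind would look roughly like this (just a sketch; the chunk size of 50 is arbitrary and net is the model from the repro below):

import torch

def predict_in_chunks(net, batch, chunk_size=50):
    # Split the batch along dim 0 into fixed-size chunks, run each chunk
    # through the network, and concatenate the outputs
    with torch.no_grad():
        return torch.cat([net(chunk) for chunk in torch.split(batch, chunk_size)])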

Here is some code to reproduce the issue; change the constant 250 to np.random.randint(1, 250) to see the difference:

import torch.nn as nn
import torch
from torchvision import models
import numpy as np

class network(nn.Module):
    def __init__(self):
        super(network, self).__init__()
        # First five layers of a pretrained AlexNet, used as a frozen feature extractor
        self.features = models.alexnet(pretrained=True).features[:5]
        for param in self.features.parameters():
            param.requires_grad = False
        self.forward_pass = nn.Sequential(
            nn.Linear(192 * 5 * 5, 10000),
            nn.ReLU(),
            nn.Linear(10000, 5000),
            nn.ReLU(),
            nn.Linear(5000, 1000)
        )

    def forward(self, x):
        x = self.features(x)
        # Flatten the 192 x 5 x 5 feature maps for the fully connected layers
        x = x.view(-1, 192 * 5 * 5)
        return self.forward_pass(x)

net = network()
for param in net.parameters():
    param.requires_grad = False
net.eval()

with torch.no_grad():
    while True:
        # Fixed batch of 250 patches; change 250 to np.random.randint(1, 250)
        # to see the memory difference with variable batch sizes
        dataset = torch.rand((250, 3, 50, 50))
        net(dataset)