Restoring models when batch size is different

I have a very simple model with 2 conv and 2 linear layers.
The input and output sizes of the linear layers depend on the batch size.
I am having a problem (size mismatch) when I try to save the model weights with one batch size and restore them with another.
How can I make save/restore independent of the batch size? Can you point me to a sample model?

Thanks!

The model should be independent of the batch size. Can you post your code, please?

My model:


import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, batch_size):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(2, 1, 3)
        # the linear layer sizes are tied to the batch size here
        self.fc1 = nn.Linear(self.batch_size * 2 * 14, 2048)
        self.fc2 = nn.Linear(2048, self.batch_size)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flattens the whole batch into a single vector
        x = x.view(self.batch_size * 2 * 14)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

And this is how I am calling it:

y_pred = model(x)
The shape of x is (batch_size, no_of_channels, h, w).

You could instantiate the original net with its weights and afterwards replace the fc layers, something like:

net = Net(batch_size=4)
net.load_state_dict(torch.load(your_weights))  # path to the saved state_dict
net.fc1 = nn.Linear(new_batch_size * 2 * 14, 2048)
net.fc2 = nn.Linear(2048, new_batch_size)

but these layers will be initialized with random weights, so you will have to train the net again. Btw, why are you using a batch-dependent network?

Because while training, it is a lot faster if I feed it a big batch size like 512.
But in production I won't always have 512 images, so I may call it with a single image, 4 images, or any other amount.
I wonder how the ResNet implementation does it, because with ResNet I can train with any batch size and call it with any other batch size.

The workaround I have found for now is to train with a batch size of 512 and, when calling it in production, fill the missing part with replicated data.
But this is not a good solution.
Thanks anyway;
I will dig into the ResNet implementation later to see how it handles this.

Also, I don't know how to make a batch-independent network that can still be trained/called with batches. Do you have any suggestions for making the above network “batch independent”?

Batches are usually only used for training, and in production you will often call the network with just one image, but you still have to pass the sample as a batch tensor (i.e. [1, no_channels, h, w]). Think of it this way: internally the network works on one sample at a time; you pass a batch tensor so the loss can be computed across all the samples in the batch and then averaged.

So the network is always batch independent.
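
To make this concrete, here is a minimal sketch of a batch-independent version of the net above. It keeps the original conv/pool stack and assumes the same input size, so the per-sample feature count after flattening is 2 * 14; the single output per sample is a placeholder for whatever your task really needs:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(2, 1, 3)
        # 2 * 14 is the per-sample feature count after the convs/pools;
        # it depends on the input h/w, never on the batch size
        self.fc1 = nn.Linear(2 * 14, 2048)
        self.fc2 = nn.Linear(2048, 1)  # placeholder: one output per sample

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # keep the batch dimension, flatten only the per-sample dims
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

With this, net(torch.randn(512, 1, h, w)) during training and net(torch.randn(1, 1, h, w)) in production both work with the same saved weights. ResNet does essentially the same thing: torchvision's implementation flattens with torch.flatten(x, 1) (keeping dim 0, the batch dimension) after an adaptive average pool, right before the final fc layer.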

If you still need help, I encountered a similar problem and found a way to work around it.
I was facing this problem because my batch size was 64 and the total number of samples was 20000, so the last batch could not be filled up to 64 and I got a size mismatch error. This happened because I had hardcoded the batch size into the reshape that feeds self.lin1, the first linear layer after all the conv layers, like this:

out = out.reshape(64, 8 * 20 * 20)

Instead of doing this, I passed the length of each training batch as a parameter to the forward function (passing len(xb) for it) and changed the reshape to

out = out.reshape(batch_size, 8 * 20 * 20)

where batch_size is the current batch size.
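
If it helps, here is a minimal sketch of that workaround; the single conv layer and the 3-channel 20x20 input are hypothetical stand-ins, only the 8 * 20 * 20 feature count and the reshape pattern come from the description above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # hypothetical conv stack that ends in 8 channels of 20x20 maps
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.lin1 = nn.Linear(8 * 20 * 20, 10)

    def forward(self, xb, batch_size):
        out = F.relu(self.conv(xb))
        # reshape with the actual size of this batch, not a hardcoded 64
        out = out.reshape(batch_size, 8 * 20 * 20)
        return self.lin1(out)

It gets called as model(xb, len(xb)), so the last, smaller batch (20000 % 64 = 32 samples) no longer causes a size mismatch.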

Well, I also missed that you could always do out.reshape(-1, 8 * 20 * 20)
without passing a batch size parameter manually.
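
With that, the forward of the sketch above simplifies to a version that needs no extra parameter:

    def forward(self, xb):
        out = F.relu(self.conv(xb))
        # -1 lets reshape infer the batch dimension from the input
        out = out.reshape(-1, 8 * 20 * 20)
        return self.lin1(out)

An equivalent and arguably safer idiom is out.reshape(out.size(0), -1): it keeps the batch dimension explicit, so a wrong feature count fails loudly at the linear layer instead of silently folding samples together.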