Restoring models when batch size is different

You could instantiate the original net, load its weights, and then replace the fc layers, something like:

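import torch
import torch.nn as nn

net = Net(batch_size=4)
net.load_state_dict(torch.load(your_weights))
# replace the batch-dependent layers (fc1's output size must match fc2's input size)
net.fc1 = nn.Linear(new_batch_size * 2 * 14, 2048)
net.fc2 = nn.Linear(2048, new_batch_size)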

But these new layers will be initialized with random weights, so you will have to train the net again (the other layers keep their loaded weights). Btw, why are you using a batch-dependent network?
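To make the layer swap concrete, here is a small self-contained sketch. Since the original `Net` wasn't posted, the class below is a hypothetical stand-in: I'm assuming the batch dependence comes from flattening the whole batch into one vector (which would explain an input size of `batch_size * 2 * 14`), and I use an in-memory state dict instead of `torch.load` on a file. The helper `restore_with_new_batch_size` is also hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Hypothetical batch-dependent net: fc1/fc2 sizes are tied to batch_size."""

    def __init__(self, batch_size):
        super().__init__()
        # Conv1d on a length-16 signal gives 2 channels x 14 positions per sample.
        self.conv = nn.Conv1d(1, 2, kernel_size=3)
        # The whole batch is flattened into one vector, hence the batch_size factor.
        self.fc1 = nn.Linear(batch_size * 2 * 14, 2048)
        self.fc2 = nn.Linear(2048, batch_size)

    def forward(self, x):  # x: (batch_size, 1, 16)
        x = F.relu(self.conv(x))  # (batch_size, 2, 14)
        x = x.reshape(-1)         # (batch_size * 2 * 14,)
        x = F.relu(self.fc1(x))   # (2048,)
        return self.fc2(x)        # (batch_size,)


def restore_with_new_batch_size(state_dict, old_batch_size, new_batch_size):
    """Hypothetical helper: load old weights, then swap in fresh fc layers."""
    net = Net(batch_size=old_batch_size)
    net.load_state_dict(state_dict)  # shapes match the saved ones, so this succeeds
    # Only these two layers change; they start from random initialization.
    net.fc1 = nn.Linear(new_batch_size * 2 * 14, 2048)
    net.fc2 = nn.Linear(2048, new_batch_size)
    return net


torch.manual_seed(0)
trained = Net(batch_size=4)
restored = restore_with_new_batch_size(trained.state_dict(), 4, 8)

out = restored(torch.randn(8, 1, 16))
print(out.shape)  # torch.Size([8])
# The conv weights were restored from the checkpoint; fc1/fc2 are new.
print(torch.equal(restored.conv.weight, trained.conv.weight))  # True
```

Note the order: the weights are loaded into a net built with the *old* batch size first, because `load_state_dict` raises an error on shape mismatches, and only afterwards are the fc layers replaced.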

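The usual way around the whole problem is to keep the batch dimension out of the layer sizes, so a saved checkpoint works for any batch size. A minimal sketch, assuming the same hypothetical shapes as before (2 x 14 features per sample) and one output per sample; `BatchFreeNet` is a made-up name, and a model like this would need to be trained once in this form:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchFreeNet(nn.Module):
    """Hypothetical batch-independent variant: layer sizes never mention batch size."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 2, kernel_size=3)  # (N, 1, 16) -> (N, 2, 14)
        self.fc1 = nn.Linear(2 * 14, 2048)          # per-sample features only
        self.fc2 = nn.Linear(2048, 1)               # one output per sample

    def forward(self, x):  # x: (N, 1, 16) for any N
        x = F.relu(self.conv(x))
        x = torch.flatten(x, start_dim=1)  # keep the batch dim: (N, 28)
        x = F.relu(self.fc1(x))            # (N, 2048)
        return self.fc2(x).squeeze(1)      # (N,)


torch.manual_seed(0)
net = BatchFreeNet()
state = net.state_dict()

# The same weights load into a fresh instance and run on any batch size.
restored = BatchFreeNet()
restored.load_state_dict(state)
for n in (1, 4, 8):
    print(n, restored(torch.randn(n, 1, 16)).shape)
```

Because `flatten(start_dim=1)` keeps each sample separate, the linear layers see `(N, 28)` inputs and `nn.Linear` handles any `N`; nothing needs to be replaced or retrained when the batch size changes.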