I get a much better result with batch size 1 than when I use a higher batch size

I am doing regression on an image. I have a fully convolutional network (no fully connected layers) and the Adam optimizer. For some reason unknown to me, when I use batch size 1 my result is much better (in testing almost 10 times better, in training more than 10 times) in both training and testing, as opposed to using higher batch sizes (64, 128, 150), which is contrary to what people have apparently found. My loss is MSE. I would like to know if someone has run into this or knows what's going on.

I also use exactly the same initialization for every training run. Moreover, I checked the results at every epoch during training, and the gap holds regardless.
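
By identical initialization I mean something along these lines, a minimal sketch where the seed value itself is arbitrary:

```python
import random

import numpy as np
import torch


def set_seed(seed=0):
    # Fix all relevant seeds before building the model, so runs with different
    # batch sizes start from exactly the same weights.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```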

Attached is my code.

This is my data loader:

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class DriveData(Dataset):

    def __init__(self, transform=None):
        # Read the inputs and targets from CSV and convert them to float tensors
        self.xs = pd.read_csv('data/train_input.csv')
        self.ys = pd.read_csv('data/train_output.csv')
        self.x_data = torch.from_numpy(np.asarray(self.xs, dtype=np.float32))
        self.y_data = torch.from_numpy(np.asarray(self.ys, dtype=np.float32))

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return len(self.xs)


dset_train = DriveData()
train_loader = DataLoader(dset_train, batch_size=1, shuffle=True, num_workers=4)

My training function:

import torch
import torch.nn.functional as F
import torch.optim as optim


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # move the current batch to the training device
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target)
        loss.backward()    # backpropagate and update the parameters
        optimizer.step()


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=.001)
for epoch in range(1):
    train(model, device, train_loader, optimizer, epoch)

Batch size 1: training loss 0.000812, testing loss 0.002547

Batch size 128: training loss 0.0171, testing loss 0.0226

Thanks

There are papers stating a smaller batch size might generalize better, and it’s not uncommon to see the effect.
This paper is an example.

Is the loss using the larger batch size eventually decreasing?

PS: You can post code by wrapping it in three backticks. I'll edit your code for readability.


Yes, it does. The algorithm works as expected. It's just weird that I get better results with batch size 1. Even the paper you mentioned suggests batch sizes between 32 and 512, which is why I am puzzled. So I was thinking maybe I have an error in my code or there is something I am overlooking.

Is the final result the one you posted in your first post, i.e. roughly 10 times lower for a batch size of one?
One common mistake would be setting size_average=False for the loss function, but apparently you left it at its default value, i.e. True.
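
As a minimal illustration of why that flag matters (using the reduction argument, which replaces size_average in current versions): a summed loss scales with the number of elements in the batch, while the default mean does not.

```python
import torch
import torch.nn.functional as F

output = torch.randn(128, 225)
target = torch.randn(128, 225)

# 'mean' (the old size_average=True default) is comparable across batch sizes,
# while 'sum' (size_average=False) grows with the number of elements.
loss_mean = F.mse_loss(output, target, reduction='mean')
loss_sum = F.mse_loss(output, target, reduction='sum')
print(loss_mean.item(), loss_sum.item())  # loss_sum ≈ loss_mean * 128 * 225
```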

Could you post your model architecture, so that we can be sure there is nothing wrong?

Yes, and yes, I kept the default setting in the loss.

This is my model:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=(1, 1))
        self.conv1_bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=(1, 1))
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 8, kernel_size=3, padding=(1, 1))
        self.conv3_bn = nn.BatchNorm2d(8)
        self.conv4 = nn.Conv2d(8, 8, kernel_size=3, padding=(1, 1))
        self.conv4_bn = nn.BatchNorm2d(8)
        self.conv5 = nn.Conv2d(8, 8, kernel_size=3, padding=(1, 1))
        self.conv5_bn = nn.BatchNorm2d(8)
        self.conv6 = nn.Conv2d(8, 16, kernel_size=3, padding=(1, 1))
        self.conv6_bn = nn.BatchNorm2d(16)
        # assumed: conv7 is used in forward() but was missing from the posted code;
        # a 16 -> 1 conv keeps the 15x15 spatial size so the output flattens to 225
        self.conv7 = nn.Conv2d(16, 1, kernel_size=3, padding=(1, 1))
        self.fc1 = nn.Linear(16 * 225, 1)  # defined but not used in forward()

    def forward(self, x):
        x = x.float()
        x = x.view(-1, 1, 15, 15)
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = F.relu(self.conv4_bn(self.conv4(x)))
        x = F.relu(self.conv5_bn(self.conv5(x)))
        x = F.relu(self.conv6_bn(self.conv6(x)))
        x = self.conv7(x)
        x = x.view(-1, 225)
        return x

My input is a 15x15 patch and so is my output. That's why I have 225.

Since you have BatchNorm layers in your model, I would think a batch size of one would perform worse than a larger batch size, because the running stats could be quite shaky. Apparently that's not the case here.

Could you explain your dataset a bit? What kind of data do you have?
You could try to play around with the momentum in the BatchNorm layers a bit and see how your results for the large batch size change.
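
For example, the momentum is set when constructing the layer and controls how quickly the running statistics follow the per-batch statistics. A minimal sketch (0.01 is just an illustration; the PyTorch default is 0.1):

```python
import torch
import torch.nn as nn

# Lower momentum -> running_mean / running_var move more slowly towards the batch statistics
bn = nn.BatchNorm2d(64, momentum=0.01)

bn.train()
x = torch.randn(128, 64, 15, 15)
bn(x)                       # updates bn.running_mean / bn.running_var in training mode
print(bn.running_mean[:3])  # slightly shifted away from the initial zeros
```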

Exactly.

I am basically doing denoising. I have a 300x300 image from which I extract overlapping 15x15 patches, so the input is the noisy image and the output is the ground-truth image. My training patches come from the left half of the image, and I also apply some rotations for augmentation. My testing patches come from the right half.
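
Roughly, the patch extraction looks like this; a simplified sketch where the stride value is just an example and the rotation augmentation is omitted:

```python
import torch

noisy = torch.randn(300, 300)                           # stands in for the noisy image
train_half, test_half = noisy[:, :150], noisy[:, 150:]  # left half / right half

# Overlapping 15x15 patches with stride 5, flattened to rows of 225 values
patches = train_half.unfold(0, 15, 5).unfold(1, 15, 5)  # shape (58, 28, 15, 15)
patches = patches.contiguous().view(-1, 225)
print(patches.shape)                                    # torch.Size([1624, 225])
# the ground-truth image is cut into patches the same way to build the targets
```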

Ok, thanks. I will also try modifying the learning rate significantly.

I was playing with the momentum as you suggested. Initially, I modified the momentum directly in my code, but I didn't really notice much of a difference in my results. Then I did this simple test:

a = torch.Tensor([[-1.5, 0], [0.5, -.8]])
b = torch.nn.functional.batch_norm(a, torch.mean(a, dim=1), torch.var(a, dim=1), momentum=0.9)

I tried different values for momentum, but the output was always the same.

I also read the definition of momentum directly from https://pytorch.org/docs/stable/nn.html, but honestly I don't completely understand what they mean by "estimated statistic" and "new observed value".

Finally, I realized that if I remove batch normalization altogether, I get results that make sense, e.g. batch size 32 is better than batch size 1. This suggests I am not applying batch normalization correctly. I would appreciate it if you could shed some light on the usage or the exact interpretation of the momentum.

Thanks

BN's functional form has training=False by default, so in your test the momentum is never used. You need training=True for the momentum to have an effect.

Also, pay attention to the note under https://pytorch.org/docs/master/nn.html?highlight=batch_norm#torch.nn.BatchNorm2d. A momentum of 0.9 is probably not what you want.
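
Here is a minimal sketch of what the momentum does once training=True, reusing your small example: the running estimates are updated in place as running = (1 - momentum) * running + momentum * batch_stat, so different momentum values now give different running statistics.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[-1.5, 0.0], [0.5, -0.8]])
running_mean = torch.zeros(2)
running_var = torch.ones(2)

# In training mode the batch is normalized with its own statistics and the
# running buffers are updated in place with the given momentum.
out = F.batch_norm(a, running_mean, running_var, training=True, momentum=0.9)
print(running_mean)  # 0.9 * per-feature batch mean, since the buffer started at zero
```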

If my complete dataset fits into RAM at once, how should I choose the batch size?

The batch size is independent of the data loading and is usually chosen based on

  1. what works well for your model and training procedure (too small or too large a batch might degrade the final accuracy), and
  2. which GPUs you are using and what fits into your device memory. Generally you can increase the device utilization by increasing the workload, e.g. through the batch size; a rough memory probe is sketched below.
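
As a rough, illustrative sketch (the helper below is not part of PyTorch, the candidate sizes are arbitrary, and a forward pass alone underestimates the memory needed during training with gradients and optimizer states), you could probe candidate batch sizes and keep the largest one that still fits on the device:

```python
import torch
from torch.utils.data import DataLoader


def largest_fitting_batch_size(model, dataset, candidates=(32, 64, 128, 256, 512)):
    # Try one forward pass per candidate batch size and keep the largest one that
    # does not run out of device memory. Assumes the dataset yields (input, target) pairs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    best = candidates[0]
    for bs in candidates:
        try:
            data, _ = next(iter(DataLoader(dataset, batch_size=bs)))
            with torch.no_grad():
                model(data.to(device))   # rough memory probe, forward pass only
            best = bs
        except RuntimeError:             # typically a CUDA out-of-memory error
            break
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return best
```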