Simple CNN for object counting only works with batch size 1

Hello!

For the last 2 weeks I’ve been stuck trying to count balls from a synthetic dataset I generated.
When I set a batch size higher than 1, the network predicts the average value all the time.
With a batch size of 1, the network works great in train/val.

  • Dataset: 5000 images with a grey background and blue balls of the same size scattered around. There are between 1 and 10 balls per image.
  • Optimizer: Adam with learning rate of 1e-4
  • Normalization: for the normalization of the images I used:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5],
                         [0.5, 0.5, 0.5]),
])
  • Loss: L1 (Mean absolute error)
  • Network:
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 15, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2))
        self.dropout = nn.Dropout(0.2)
        self.fc1 = nn.Linear(12615, 1000)
        self.fc2 = nn.Linear(1000, 1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc1(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out
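
The in_features of fc1 (12615) follows from the input size: each unpadded 5x5 convolution shrinks both spatial dimensions by 4 and each 2x2 max-pool halves them (rounding down), so 128 -> 124 -> 62 -> 58 -> 29 and 15 * 29 * 29 = 12615. Rather than computing this by hand, one option is to push a dummy tensor through the convolutional part. A minimal sketch that rebuilds only the conv/pool layers from the post; `flat_size` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

# Same conv/pool stack as layer1 + layer2 of ConvNet above.
features = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 15, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
)

def flat_size(side: int) -> int:
    """Number of features the conv stack produces for one side x side RGB image."""
    with torch.no_grad():
        out = features(torch.zeros(1, 3, side, side))
    return out.numel()

print(flat_size(128))  # 15 * 29 * 29 = 12615
print(flat_size(64))   # 15 * 13 * 13 = 2535
```

If the image size changes, the printed value is what fc1's in_features must be set to.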

I am a bit frustrated because this is the easiest example of object counting and I cannot make it work properly.

How can I avoid predicting the average value (5) when raising the batch size?

P.S.: I've tried many things: cyclic LR, BatchNorm2d, deeper networks, all the possible optimizers…

Thanks!

Are you using the summed loss or the mean (which is the default)?
Also, did you try to use MSELoss instead of L1Loss?

Since you have a fixed number of balls (1 to 10), you could also treat this use case as a multi-class classification and use nn.CrossEntropyLoss instead. However, a classifier can't extrapolate beyond the classes it was trained on, if that's part of your use case.
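
A sketch of the classification variant, assuming counts 1 to 10 are mapped to class indices 0 to 9. The head size reuses the 12615 features of the 128x128 ConvNet, and the random `features` tensor is a stand-in for the flattened conv output. Note that nn.CrossEntropyLoss expects a target of shape [batch_size] with dtype long, unlike the [batch_size, 1] float output of the regression model:

```python
import torch
import torch.nn as nn

NUM_CLASSES = 10  # counts 1..10 -> class indices 0..9

# Hypothetical head: 10 logits instead of a single regression output.
head = nn.Linear(12615, NUM_CLASSES)
criterion = nn.CrossEntropyLoss()

features = torch.randn(4, 12615)      # stand-in for flattened conv features
counts = torch.tensor([1, 5, 10, 3])  # ball counts (int64)

logits = head(features)               # shape [batch_size, 10]
targets = counts - 1                  # class indices 0..9, shape [batch_size]
loss = criterion(logits, targets)

predicted_counts = logits.argmax(dim=1) + 1  # map class indices back to counts
```

Because the output is one of 10 classes, this model can never predict 11 or more balls, which is the extrapolation limit mentioned above.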

Thanks @ptrblck for the quick reply!

Can you explain how to use a summed loss?
I think that in my case I just average all the losses:

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        successes = 0.0
        total = 0.0

        for inputs, labels in tqdm(dataloaders[phase]):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Set gradients to zero
            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)

                # Counting L1 loss
                loss = criterion(outputs, labels.float())

I tried MSELoss with the same results.
I also considered nn.CrossEntropyLoss, but it would apply the same penalty to predicting 7 instead of 8 as to predicting 1 instead of 8. I agree it may work for this example, but it won't work if I scale up the difficulty of the task.
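
The concern can be checked directly: cross-entropy only depends on the probability assigned to the correct class, so it does not care which wrong class received the rest of the probability mass. A small sketch with made-up logits, assuming the true count is 8 (class index 7):

```python
import torch
import torch.nn.functional as F

true_class = torch.tensor([7])  # count 8 -> class index 7

# Two equally confident predictions, peaked at different wrong classes.
logits_near = torch.zeros(1, 10)
logits_near[0, 6] = 5.0  # peaks at count 7
logits_far = torch.zeros(1, 10)
logits_far[0, 0] = 5.0   # peaks at count 1

ce_near = F.cross_entropy(logits_near, true_class)
ce_far = F.cross_entropy(logits_far, true_class)

# L1 on the predicted counts does distinguish the two errors.
l1_near = abs(7 - 8)
l1_far = abs(1 - 8)

print(ce_near.item(), ce_far.item())  # same value (up to floating point)
print(l1_near, l1_far)                # 1 7
```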

Thanks!

You would have to set it explicitly when creating the criterion, via reduction='sum'.
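
A minimal sketch of the two reductions on a toy batch. With matching shapes, the summed loss is the mean loss times the number of elements, so switching only rescales the gradients by a constant. Adam is largely insensitive to such a constant rescaling, so little change in behavior is to be expected from this switch alone:

```python
import torch
import torch.nn as nn

output = torch.tensor([[2.0], [4.0], [7.0]])
target = torch.tensor([[1.0], [5.0], [5.0]])

loss_mean = nn.L1Loss()(output, target)                # default reduction='mean'
loss_sum = nn.L1Loss(reduction='sum')(output, target)  # summed over all elements

# Absolute differences are 1, 1 and 2: mean = 4/3, sum = 4.0
print(loss_mean.item(), loss_sum.item())
```
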
Weird behavior, indeed.
Could you add an F.relu between the last two linear layers?
Also, how large are your images?

The results with reduction='sum' and reduction='mean' are practically the same (I had to adapt the way I was displaying the loss). The success rate is still stuck at 10% (predicting only 5s).
The F.relu didn’t provide any significant improvement.
The images are 128x128, it’s just a dataset I made for testing.

What else can I try? How can I debug where the problem is? Honestly, I am running out of ideas.

Thanks again

I would try to scale down the problem size and try to overfit a smaller data sample.
E.g. use only 10 batches of the desired batch size and try to make your model converge.
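
A sketch of the "overfit 10 batches" check. torch.utils.data.Subset restricts the dataset to the first 10 * batch_size samples, and the loop trains on those alone. Random 64-dimensional vectors and a small MLP stand in for the real images and ConvNet (assumptions, to keep the example self-contained). The targets are stored with shape [N, 1] so they match the model output:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)

# Stand-in data: random vectors with integer counts 1..10 as float targets [N, 1].
inputs = torch.randn(500, 64)
counts = torch.randint(1, 11, (500, 1)).float()
full_dataset = TensorDataset(inputs, counts)

batch_size = 4
small = Subset(full_dataset, range(10 * batch_size))  # only 10 batches of data
loader = DataLoader(small, batch_size=batch_size, shuffle=True)

# Stand-in model: a small MLP instead of ConvNet.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def subset_loss():
    """Mean L1 loss over the small subset, without updating the model."""
    model.eval()
    with torch.no_grad():
        losses = [criterion(model(x), y) for x, y in loader]
    model.train()
    return torch.stack(losses).mean().item()

initial = subset_loss()
for epoch in range(200):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)  # output and target are both [batch_size, 1]
        loss.backward()
        optimizer.step()
final = subset_loss()
print(f"loss on the 10-batch subset: {initial:.3f} -> {final:.3f}")
```

If a model cannot drive this loss close to zero on so few samples, the problem is most likely in the pipeline (data, shapes, loss) rather than in the model capacity.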

Also, another maybe “unusual” value is the high number of input features to the first linear layer.
I would try to lower this number by either resizing the images or adding another conv/pool layer.
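
A sketch of the second option, adding a third conv/pool block for 128x128 inputs. The 30 output channels of the new layer are an arbitrary choice for illustration, dropout is omitted for brevity, and `ConvNet3` is a hypothetical name. The comments trace the spatial size through each block:

```python
import torch
import torch.nn as nn

class ConvNet3(nn.Module):
    """Hypothetical ConvNet variant with a third conv/pool block."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),    # 128 -> 124 -> 62
            nn.Conv2d(6, 15, 5), nn.ReLU(), nn.MaxPool2d(2, 2),   # 62 -> 58 -> 29
            nn.Conv2d(15, 30, 5), nn.ReLU(), nn.MaxPool2d(2, 2),  # 29 -> 25 -> 12
        )
        self.fc1 = nn.Linear(30 * 12 * 12, 512)  # 4320 features instead of 12615
        self.fc2 = nn.Linear(512, 1)

    def forward(self, x):
        out = self.features(x).flatten(1)
        out = torch.relu(self.fc1(out))
        return self.fc2(out)

model = ConvNet3()
print(model(torch.zeros(2, 3, 128, 128)).shape)  # torch.Size([2, 1])
```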

I am using only 44 samples (and batch size 4), but I still can't overfit the model when BS > 1.
I resized the images to 64x64, so the input to the first FC layer is now smaller:

import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 15, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2))
        self.dropout = nn.Dropout(0.2)
        self.fc1 = nn.Linear(2535, 512)          
        self.fc2 = nn.Linear(512, 1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

Unfortunately, the model keeps predicting the average value.

Do you know of any object counting examples in computer vision?

All the object counting examples I found online (surprisingly, 99% of PyTorch examples are classifiers) use a detector (such as YOLO) or density maps (for crowd counting). Maybe density maps are the way to go.
Still, it's disappointing that I can't solve such an easy task as counting balls with a PyTorch CNN.
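
The density-map idea mentioned above turns counting into per-pixel regression: the target is a map that sums to the number of objects, and the predicted count is the sum of the predicted map. A minimal sketch of building such a target, assuming the ball centers are known (plausible for a synthetic dataset, but an assumption here); `density_map` and its `sigma` default are hypothetical choices:

```python
import numpy as np

def density_map(centers, size=64, sigma=2.0):
    """Sum of one Gaussian per object center, each normalized to integrate to 1."""
    ys, xs = np.mgrid[0:size, 0:size]
    density = np.zeros((size, size), dtype=np.float64)
    for cy, cx in centers:
        g = np.exp(-((ys - cy) ** 2 + (xs - cx) ** 2) / (2 * sigma ** 2))
        density += g / g.sum()  # normalize on the grid, so border objects still count as 1
    return density

centers = [(10, 12), (30, 40), (62, 1)]  # hypothetical ball centers (row, col)
d = density_map(centers)
print(d.sum())  # approximately 3.0, the number of balls
```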

Thanks for your help @ptrblck

I published the clean code here as a showcase of the issue.

Thanks for the executable code, that was really helpful.

You are accidentally broadcasting the loss, since the shapes of your output and target tensors don't match.
While your output has the shape [batch_size, 1], the target has the shape [batch_size].

This results in broadcasting, as seen here:

import torch
import torch.nn as nn

# your code with the broadcasting
output = torch.randn(4, 1)
target = torch.randn(4)

criterion = nn.L1Loss(reduction="none")

loss = criterion(output, target)
print(loss)  # shape [4, 4]; you only want the diagonal
# example output (values are random):
# tensor([[1.0231, 2.3743, 2.4857, 2.3248],
#         [1.5896, 0.2385, 0.1270, 0.2879],
#         [1.7572, 0.4061, 0.2946, 0.4555],
#         [1.2650, 0.0862, 0.1976, 0.0368]])

# fixed
target = target.unsqueeze(1)
loss = criterion(output, target)
print(loss)  # shape [4, 1]
# tensor([[1.0231],
#         [0.2385],
#         [0.2946],
#         [0.0368]])

You should also get a warning such as:

UserWarning: Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
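
With the default warning filter this message is printed only once per call site, so it is easy to miss in a long tqdm log. While debugging, Python's standard warnings module can promote it to an exception so the first mismatched call fails loudly. A sketch with random stand-in tensors:

```python
import warnings

import torch
import torch.nn as nn

criterion = nn.L1Loss()
outputs = torch.randn(4, 1)            # like the model output [batch_size, 1]
labels = torch.randint(1, 11, (4,))    # like the DataLoader labels [batch_size]

with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)  # turn UserWarnings into exceptions
    try:
        criterion(outputs, labels.float())
        mismatch_caught = False
    except UserWarning as err:
        mismatch_caught = True
        print("caught:", err)

    # With matching shapes no warning is raised.
    loss = criterion(outputs, labels.float().unsqueeze(1))
```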

Use this line of code to calculate the loss and it should work:

loss = criterion(outputs, labels.float().unsqueeze(1))
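
The broadcasting also explains the symptoms in this thread. With an output of shape [B, 1] and a target of shape [B], the loss averages |output_i - target_j| over all B x B pairs, so every prediction is trained against every label in the batch and its own label carries no special weight. Each output is then pulled toward the median of the batch labels (the mean, for MSELoss), which is why the model settles on the middle of the 1 to 10 range. With batch size 1 there is only one pair, which is why that setting worked. A small check with random tensors:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
outputs = torch.randn(4, 1)  # [batch_size, 1], like the model output
labels = torch.randn(4)      # [batch_size], like the DataLoader labels

criterion = nn.L1Loss()  # default reduction='mean'

# Buggy call: broadcasts to [4, 4] and averages |output_i - label_j| over ALL pairs.
buggy = criterion(outputs, labels)
all_pairs = (outputs - labels.view(1, -1)).abs().mean()

# Correct call: one label per output, averaged over the 4 matching pairs only.
fixed = criterion(outputs, labels.unsqueeze(1))
diagonal = (outputs.squeeze(1) - labels).abs().mean()

# With batch size 1 there is a single pair, so the bug has no effect.
single_buggy = criterion(outputs[:1], labels[:1])
single_fixed = criterion(outputs[:1], labels[:1].unsqueeze(1))
```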

We were all in the same situation, so please don't be disappointed. ;)


Thanks a million! I knew it was a silly mistake!
