Model is learning to predict just one class

Hey all,

I created a model that processes two inputs through the same network with shared weights. Inside this network I implemented stacked dilated convolutions that share weights within each layer, as described in https://arxiv.org/abs/1904.03076.

I tried to overfit the model for 100 epochs on a single batch, but my network keeps predicting just one class (1, 2, or 3).

My model class is

import torch
import torch.nn as nn

class SDCNetwork(nn.Module):
    def __init__(self, num_layers, input_size, n_conv, kernel_sizes, n_kernels, dilations):
        super(SDCNetwork, self).__init__()
        self.input_size = input_size
        self.n_conv = n_conv
        self.kernel_sizes = kernel_sizes
        self.n_kernels = n_kernels
        self.dilations = dilations
        self.num_classes = 3

        layers = nn.ModuleList([])
        # One SDC layer per level of the network
        for i in range(num_layers):
            layers.append(
                SDCLayer(input_size=self.input_size[i],
                         n_conv=self.n_conv,
                         kernel_sizes=self.kernel_sizes,
                         n_kernels=self.n_kernels[i],
                         dilations=self.dilations)
            )

        self.features = nn.Sequential(*layers)

        # Class activation map layer, similar to the ACOL paper's output
        self.cam = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding='same'),
            nn.SiLU(),
            nn.Conv2d(256, self.num_classes, kernel_size=1, padding='same'),
            nn.SiLU()
        )

        # Global average pooling
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, x1, x2):
        # Process different inputs with the same weights (SDC Network)
        x1, x2 = self.features(x1), self.features(x2)
        
        # Concatenate the 256 x 512 image to be 512 x 512
        x = torch.cat((x1, x2), dim=2)
        
        # Class activation map
        x = self.cam(x)
        # Global average pooling
        x = self.gap(x)
        # Flatten
        x = x.view(x.size(0), -1)
        
        return x
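
(SDCLayer itself is omitted for brevity; roughly, it applies n_conv parallel convolutions whose weights are reused for every dilation rate and concatenates the responses, per the paper. Below is a minimal sketch of that idea, not my exact implementation; note that it outputs n_conv * len(dilations) * n_kernels channels, so the hyperparameters have to make the last layer emit the 128 channels the CAM head expects.)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SDCLayer(nn.Module):
    # Sketch only: n_conv parallel convolutions whose weights are shared
    # across all dilation rates; responses are concatenated channel-wise.
    def __init__(self, input_size, n_conv, kernel_sizes, n_kernels, dilations):
        super(SDCLayer, self).__init__()
        self.dilations = dilations
        self.convs = nn.ModuleList([
            nn.Conv2d(input_size, n_kernels, kernel_size=kernel_sizes[i])
            for i in range(n_conv)
        ])

    def forward(self, x):
        outs = []
        for conv in self.convs:
            k = conv.kernel_size[0]
            for d in self.dilations:
                # 'same' padding for an odd kernel of size k at dilation d
                outs.append(F.conv2d(x, conv.weight, conv.bias,
                                     padding=d * (k - 1) // 2, dilation=d))
        return F.silu(torch.cat(outs, dim=1))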

My training is

def run_one_batch(model, batch, criterion, optimizer, scheduler, device, num_epochs=100):

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        model.train()  # Set model to training mode

        running_loss = 0.0
        running_corrects = 0
            
        # Iterate over data
        loop = tqdm(batch, total=len(batch), leave=False)
        images, patient = batch[0]['image'], batch[0]['patient']
        label = batch[1]
        lung1, lung2 = images[0], images[1]
        lung1, lung2 = lung1.to(device), lung2.to(device)
        label = label.to(device)

        # Forward
        outputs = model(lung1, lung2)
        probs = F.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)
        loss = criterion(outputs, label)

        # Zero the parameter gradients
        optimizer.zero_grad()
        # Backward + optimize
        loss.backward()
        optimizer.step()
        loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
        loop.set_postfix(loss=loss.item(), acc=(float(torch.sum(preds == label).item())/float(outputs.size(0))))

        # Statistics
        running_loss += loss.item()
        running_corrects += torch.sum(preds == label).item()
        scheduler.step(running_loss)

        # Compute epoch metrics and loss
        epoch_loss = float(running_loss) / float(len(batch))
        epoch_acc = float(running_corrects) / float(len(batch))

The parameters are

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, betas=(0.9, 0.999), weight_decay=1e-4)  # optimize just the learnable parameters
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=1, verbose=True)  # LR decay on plateau
criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
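
For completeness, a hypothetical end-to-end wiring of these pieces looks like the following. Every value is a placeholder (chosen only so the channel counts line up with the Conv2d(128, ...) head under the SDCLayer sketch above), and train_loader is assumed to yield batches shaped like the ones run_one_batch expects:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder sizes: 2 layers x 2 parallel convs x 4 dilations, so the last
# layer emits 16 * 2 * 4 = 128 channels, matching the CAM head's input.
model = SDCNetwork(num_layers=2, input_size=[1, 64], n_conv=2,
                   kernel_sizes=[3, 5], n_kernels=[8, 16],
                   dilations=[1, 2, 4, 8]).to(device)

class_weights = torch.ones(3, device=device)  # assumed; the real weights aren't shown
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                        lr=3e-4, betas=(0.9, 0.999), weight_decay=1e-4)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=1)
criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

batch = next(iter(train_loader))  # train_loader assumed to exist
run_one_batch(model, batch, criterion, optimizer, scheduler, device)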

I would really appreciate it if someone could help me with this problem; I'm quite stuck on it.
Thank you!

I don’t know exactly how this code snippet is supposed to work:

        # Iterate over data
        loop = tqdm(batch, total=len(batch), leave=False)
        images, patient = batch[0]['image'], batch[0]['patient']

but it seems as if you are using a single sample only?
Could you explain this approach a bit more and how you are iterating the entire dataset?
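
Usually the DataLoader itself is what you iterate, e.g. something like this (a sketch reusing your variable names; dataloader, model, optimizer, etc. are assumed to be in scope):

loop = tqdm(dataloader, total=len(dataloader), leave=False)
for batch in loop:  # consume the iterator instead of indexing a single element
    images, patient = batch[0]['image'], batch[0]['patient']
    lung1, lung2 = images[0].to(device), images[1].to(device)
    label = batch[1].to(device)

    optimizer.zero_grad()
    outputs = model(lung1, lung2)
    loss = criterion(outputs, label)
    loss.backward()
    optimizer.step()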

Hey! Thanks for your reply.

This code snippet just iterates over a single batch. My goal here is to make sure everything is working correctly. But as I said, the model is not overfitting; it keeps predicting a single class.

When I run on the entire dataset, the same thing happens: the model learns to predict, for example, class 1 for every example.

I realized that my output gives the same value for every class:

Probs: tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
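
To narrow down where the collapse starts, I can compare the raw logits and the intermediate activation statistics (a sketch; variable names as in the training snippet above). If the logits are already identical across classes, the problem is upstream of the softmax:

model.eval()
with torch.no_grad():
    logits = model(lung1, lung2)  # raw outputs, before softmax
    print('Logits:', logits)      # equal values within each row -> collapse upstream
    feats = torch.cat((model.features(lung1), model.features(lung2)), dim=2)
    cam_out = model.cam(feats)
    # Near-zero mean/std would mean the CAM activations die before pooling
    print('CAM mean/std:', cam_out.mean().item(), cam_out.std().item())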