Model Accuracy stuck at 0.5 though Loss is consistently decreasing

This is using PyTorch

I have been trying to implement a UNet model on my images, but my model accuracy is always exactly 0.5, even though the loss does decrease.

I have also checked for class imbalance, and I have tried playing with the learning rate. The learning rate affects the loss but not the accuracy.

My architecture is below (from here):

""" `UNet` class is based on https://arxiv.org/abs/1505.04597

The U-Net is a convolutional encoder-decoder neural network.
Contextual spatial information (from the decoding,
expansive pathway) about an input tensor is merged with
information representing the localization of details
(from the encoding, compressive pathway).

Modifications to the original paper:
(1) padding is used in 3x3 convolutions to prevent loss
    of border pixels
(2) merging outputs does not require cropping due to (1)
(3) residual connections can be used by specifying
    UNet(merge_mode='add')
(4) if non-parametric upsampling is used in the decoder
    pathway (specified by up_mode='upsample'), then an
    additional 1x1 2d convolution occurs after upsampling
    to reduce channel dimensionality by a factor of 2.
    With up_mode='transpose', this channel halving happens
    inside the transpose convolution itself.


    Arguments:
        in_channels: int, number of channels in the input tensor.
                     Default is 3 for RGB images. Our SPARCS dataset is 13 channel.
              depth: int, number of MaxPools in the U-Net. During training, the
                     input spatial size must be divisible by 2**(depth-1).
        start_filts: int, number of convolutional filters for the first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose' for
                     transpose convolution, 'upsample' for nearest-neighbour
                     upsampling followed by a 1x1 convolution.

"""

class UNet(nn.Module):

    def __init__(self, num_classes, depth, in_channels, start_filts=16, up_mode='transpose', merge_mode='concat'):

        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for upsampling. Only \"transpose\" and \"upsample\" are allowed.".format(up_mode))
    
        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for merging up and down paths.Only \"concat\" and \"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use nearest neighbour to reduce depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
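        # DownConv, UpConv, and conv1x1 are helper modules/functions from the
        # linked repo (not shown here).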
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            pooling = i < depth - 1  # pool in every block except the last

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode, merge_mode=merge_mode)
            self.up_convs.append(up_conv)
            

        self.conv_final = conv1x1(outs, self.num_classes)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            
            # See https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/
            # Docs: https://pytorch.org/docs/stable/nn.init.html?highlight=xavier#torch.nn.init.xavier_normal_
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)



    def reset_params(self):
        for m in self.modules():
            self.weight_init(m)
        

    def forward(self, x):
        encoder_outs = []
         
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
        
        # No softmax is used here, so use nn.CrossEntropyLoss in your
        # training script, as that criterion applies log-softmax internally.
        x = self.conv_final(x)
        return x

The parameters are:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x, y = train_sequence[0]; batch_size = x.shape[0]
model = UNet(num_classes = 2, depth=5, in_channels=5, merge_mode='concat').to(device)
optim = torch.optim.Adam(model.parameters(),lr=0.01, weight_decay=1e-3)
criterion = nn.BCEWithLogitsLoss() #has sigmoid internally
epochs = 1000
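
A quick shape sanity check (a sketch; the 64x64 size is just illustrative, since with depth=5 the spatial size must be divisible by 2**4 = 16):

dummy = torch.randn(4, 5, 64, 64).to(device)  # [batch, in_channels, H, W]
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # torch.Size([4, 2, 64, 64]) -> [batch, num_classes, H, W]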

The training function is:

import torch.nn.functional as f

def train_model(epoch,train_sequence):
    """Train the model and report validation error with training error
    Args:
        model: the model to be trained
        criterion: loss function
        data_train (DataLoader): training dataset
    """
    model.train()
    for idx in range(len(train_sequence)):        
        X, y = train_sequence[idx]  
        
        images = torch.from_numpy(X).to(device)  # [batch, channel, H, W]
        masks = torch.from_numpy(y).to(device)
        
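        # NOTE: f.normalize defaults to dim=1, so the calls below rescale each
        # pixel's channel vector to unit L2 norm, and they are applied to the
        # masks as well as the images.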
        images = f.normalize(images, p=2)
        masks = f.normalize(masks, p=2)
        
        optim.zero_grad()        
        outputs = model(images)
#         print(masks.shape, outputs.shape)
        loss = criterion(outputs, masks)
        loss.backward() 
        optim.step() # Update weights    
    total_acc, total_loss = get_loss_train(model, train_sequence)

The loss and accuracy calculation is:

def get_loss_train(model, train_sequence):
    """
        Calculate average loss and accuracy over the train set
    """
    model.eval()
    total_acc = 0
    total_loss = 0
    for idx in range(len(train_sequence)):        
        with torch.no_grad():
            X, y = train_sequence[idx]             
            images = torch.from_numpy(X).to(device)  # [batch, channel, H, W]
            masks = torch.from_numpy(y).to(device)
            
            images = f.normalize(images, p=2)
            masks = f.normalize(masks, p=2)

            outputs = model(images)
            loss = criterion(outputs, masks)
            preds = torch.argmax(outputs, dim=1).float()
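            # accuracy_check_for_batch is a user-defined helper (not shown here)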
            acc = accuracy_check_for_batch(masks.cpu(), preds.cpu(), images.size()[0])
            total_acc = total_acc + acc
            total_loss = total_loss + loss.cpu().item()
    return total_acc/(len(train_sequence)), total_loss/(len(train_sequence))

The code that calls these functions:

for epoch in range(epochs):
    train_model(epoch, train_sequence)
    train_acc, train_loss = get_loss_train(model,train_sequence)
    print("Train Acc:", train_acc)
    print("Train loss:", train_loss)

Can someone help me identify why the accuracy is always exactly 0.5?

Since you are using nn.BCEWithLogitsLoss, I assume you are dealing with a multi-label segmentation (i.e. each pixel might belong to zero, one, or more classes)?
Is this correct or would you like to apply a multi-class segmentation (i.e. each pixel belongs to one class only)?

Your accuracy calculation looks like multi-class segmentation, as you are calling torch.argmax(outputs, dim=1), so you are currently mixing the two use cases.

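Here is a minimal sketch of the two interpretations (the shapes are made up for illustration):

import torch
import torch.nn as nn

outputs = torch.randn(4, 2, 8, 8)        # logits: [batch, 2, H, W]

# Multi-label view: nn.BCEWithLogitsLoss expects a float target of the same shape.
bce_target = torch.randint(0, 2, (4, 2, 8, 8)).float()
loss = nn.BCEWithLogitsLoss()(outputs, bce_target)

# Multi-class view: torch.argmax picks one class per pixel -> [batch, H, W].
preds = torch.argmax(outputs, dim=1)
print(loss.item(), preds.shape)          # preds: torch.Size([4, 8, 8])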

Thanks @ptrblck for the reply. I really appreciate it.

I should have included the problem statement too. The end goal is segmentation of an image: I have only two classes, and each pixel in the image belongs to exactly one class.

I think this would come under binary-class and single-label, as per what I read here.

I read a discussion here but couldn’t ascertain which one I should use.

Does that mean I should use nn.CrossEntropyLoss? I'm sorry, I did read a few forum threads but wasn't sure what was going wrong, so I asked this question.

In that case you could either:

  • Use nn.BCEWithLogitsLoss as your criterion, make sure your model outputs logits with the shape [batch_size, 1, height, width], and use a target of the same shape containing your labels (0 and 1). For the accuracy calculation, you can apply a threshold of 0 to get the predicted class: preds = outputs > 0.
  • Treat your binary segmentation as a multi-class segmentation. Use nn.CrossEntropyLoss as the criterion, make sure your model outputs logits in the shape [batch_size, nb_classes, height, width], and make your target a LongTensor in the shape [batch_size, height, width] containing the class indices (0 and 1). To get the prediction, you can stick to the torch.argmax approach. Both options are sketched below.
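
Minimal sketches of both options (random tensors stand in for real data; the shapes here are made up):

import torch
import torch.nn as nn

# Option 1: binary segmentation with a single output channel.
logits = torch.randn(4, 1, 8, 8)                    # [batch, 1, H, W]
target = torch.randint(0, 2, (4, 1, 8, 8)).float()  # same shape, values 0 and 1
loss1 = nn.BCEWithLogitsLoss()(logits, target)
preds1 = logits > 0                                 # threshold at 0 in logit space

# Option 2: treat it as a 2-class multi-class segmentation.
logits2 = torch.randn(4, 2, 8, 8)                   # [batch, nb_classes, H, W]
target2 = torch.randint(0, 2, (4, 8, 8))            # LongTensor of class indices
loss2 = nn.CrossEntropyLoss()(logits2, target2)
preds2 = torch.argmax(logits2, dim=1)               # [batch, H, W]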

@ptrblck

Use nn.BCEWithLogitsLoss as your criterion, make sure your model outputs logits with the shape [batch_size, 1, height, width]

Shouldn’t it be [batch_size, 2, height, width] because my outputs have two classes?

use a target with the same shape containing your labels (0 and 1).

Again, with classes 0 and 1, shouldn't the number of channels in the target be 2?

Edit:
(Taking two classes as : clear class and dark class.)

I just realized what you meant. Correct me if I am wrong, but when you say 1 in [batch_size, 1, height, width], nn.BCEWithLogitsLoss gives a probability for that class: a lower value means a lower probability for that label, and a higher value a higher probability.

However, I still don’t get what is meant when you said:

use a target with the same shape containing your labels (0 and 1)

With your suggestion, won't I have only one label now (initially I was giving two labels, something like clear class and dark class)? The prediction from the model for this label should be 0 or 1 (as it is a segmentation model).

Sorry for so many questions.

If the classes are mutually exclusive, i.e. each sample/pixel is either clear or dark, you could still treat it as a binary classification use case, where 0 could mean clear and 1 could mean dark.
On the other hand, if each sample can belong to no class, one class, or both, you would need to use 2 output neurons.
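
As a rough sketch (assuming mask holds the class indices, 0 = clear and 1 = dark):

import torch

mask = torch.randint(0, 2, (4, 8, 8))    # [batch, H, W], values 0 and 1
bce_target = mask.unsqueeze(1).float()   # -> [batch, 1, H, W], float target for BCE

# After sigmoid, the single logit gives the probability of class 1 ("dark");
# the probability of class 0 ("clear") is just 1 - p.
logit = torch.randn(4, 1, 8, 8)
p_dark = torch.sigmoid(logit)
p_clear = 1.0 - p_dark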

Hi @ptrblck,

I did what you suggested here:

  • Use nn.BCEWithLogitsLoss as your criterion, make sure your model outputs logits with the shape [batch_size, 1, height, width], and use a target of the same shape containing your labels (0 and 1). For the accuracy calculation, you can apply a threshold of 0 to get the predicted class: preds = outputs > 0.

However, two things happen (which are very contradictory to each other):

  1. Accuracy comes out to be 86%, though the results look bad.
  2. The preds array only contains False. I tested it using:
    print(np.where(preds.data.numpy()[0,0,:,:])[0])

The model trains now, but the results are poor. :(
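
For reference, a rough sketch of how the raw logits can be inspected before thresholding (model and images are from the training code above; the names are just for illustration):

# If every logit is negative, preds = outputs > 0 is all False regardless of
# the input, and the 86% "accuracy" may just reflect the label distribution.
outputs = model(images)                     # raw logits, shape [batch, 1, H, W]
print(outputs.min().item(), outputs.max().item(), outputs.mean().item())
probs = torch.sigmoid(outputs)              # probability of class 1 per pixel
print((probs > 0.5).float().mean().item())  # fraction of pixels predicted as 1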