Using MSELoss instead of CrossEntropy for Ordinal Regression/Classification

Hi everyone,
I have come across multiple examples that illustrate how a CNN works for classification tasks. However, there is very little out there that actually illustrates how a CNN can be modified for a regression task, particularly an ordinal regression task whose outputs lie in the range 0 to 4.

I understand that this problem can be treated as a classification problem by employing the cross-entropy loss. However, I think MSELoss() would work better, since you would prefer a 0 being misclassified as a 1 rather than as a 4 (a squared-error penalty of 1 versus 16).

I use a pre-trained torchvision model for this task, together with the CrossEntropy loss.

import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = models.resnet18(pretrained=True)
fc_in_features = model.fc.in_features
model.fc = nn.Linear(fc_in_features, 5)

# DEFINE A FUNCTION TO TRAIN THE MODEL

def train_model(model, dataloaders, criterion, optimizer, lr_scheduler, num_epochs=25):
    since = time.time()

    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(1, num_epochs+1):
        
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            all_preds = []
            all_labels = []
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # Get model predictions
                    _, preds = torch.max(outputs, 1)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                all_preds.append(preds)
                all_labels.append(labels)
            epoch_loss = running_loss / len(dataloaders[phase].sampler)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].sampler)
            all_labels = torch.cat(all_labels, 0)
            all_preds = torch.cat(all_preds, 0)
            # step the plateau scheduler on the validation loss only
            if phase == 'val':
                lr_scheduler.step(epoch_loss)

            print('{} Loss: {:.4f} - Acc: {:.4f} '.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_loss_history.append(epoch_loss)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_loss_history, val_loss_history, train_acc_history, val_acc_history


# TRAIN THE NETWORK 
EPOCHS = 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
model, train_loss_history, val_loss_history, train_acc_history, val_acc_history = train_model(
    model=model, dataloaders=data_loaders, criterion=criterion,
    optimizer=optimizer, lr_scheduler=lr_scheduler, num_epochs=EPOCHS)

I am confused about how to convert this code to use MSELoss or SmoothL1Loss. If I merely change the loss function, the dimensions do not match.

The solution I tried was to convert the labels into one-hot encodings and to add a softmax function to the output layer of the network. This gives model outputs and targets of the same shape, and I then use torch.argmax() to convert the outputs and targets back into the range 0-4. However, the model then always predicts a single class, which is weird.
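Roughly, what I tried looks like this (a simplified sketch of my attempt; F is torch.nn.functional):

num_classes = 5
outputs = torch.softmax(model(inputs), dim=1)             # [batch_size, 5]
targets_one_hot = F.one_hot(labels, num_classes).float()  # [batch_size, 5], same shape as outputs
loss = criterion(outputs, targets_one_hot)                # criterion = nn.MSELoss()
preds = torch.argmax(outputs, dim=1)                      # back to class indices 0-4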

I would like to understand four things:

  1. Do we convert the labels into one-hot encodings in this case?
  2. Do we have to change the network's output layer to be able to use MSELoss or SmoothL1Loss?
  3. How do we handle the dimension mismatch when we use these losses?
  4. Assuming that we apply these loss functions, how do we convert the model's output back to the range 0-4 so as to calculate the accuracy?

Thank you in advance. I hope that this discussion can finally outline a clear pipeline to use for regression tasks.

I think you could use a single output unit in the last layer via:

model.fc = nn.Linear(fc_in_features, 1)

and use nn.MSELoss directly by providing the target in the shape [batch_size, 1], containing the class indices as float values.
As you said, you would usually use e.g. nn.CrossEntropyLoss, but if your target distances are ordered, this approach might work for you.
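Something like this (a minimal sketch; labels is assumed to contain the class indices as a LongTensor):

model.fc = nn.Linear(fc_in_features, 1)  # single regression output
criterion = nn.MSELoss()

# inside the training loop:
outputs = model(inputs)                # [batch_size, 1]
targets = labels.float().unsqueeze(1)  # class indices as floats in [batch_size, 1]
loss = criterion(outputs, targets)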

Hi @ptrblck, thank you. This works.

However, do you know how I can get the prediction values back from the regression model, given that the targets are actually categories in my case? Would merely rounding the output of the regression model work?

For instance, my targets are [1, 2, 4, 2, 2, 1, 3, 1] and my model output is [0.23, -0.12, 0.13, 0.17, 0.71, -0.26, 0.45, 0.37]. Do you know how I can get the predictions back in the form of the targets?

I would also try rounding these values to the nearest class index.
Alternatively, you could try to find thresholds using ROCs, but as explained before, your use case is rather unusual, so you would most likely have to write a custom approach.
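For example (a minimal sketch, assuming the valid class indices are 0 to 4):

preds = outputs.squeeze(1).round().clamp(0, 4).long()  # nearest class index, clipped to the valid range
acc = (preds == labels).float().mean()                 # compare against the integer targets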