Predicting a parent class based on the predictions of subclasses

I am working on an action recognition task, and I want to predict the parent class ‘Diagnosis’ based on the predictions of the subclasses ‘action_name’ and ‘priority’.
The script I used was inspired by this reference; I changed the model implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class MultiOutputModel(nn.Module):
    def __init__(self, n_action_classes, n_priority_classes, n_diagnosis_classes):
        super().__init__()
        mobilenet = models.mobilenet_v2()
        self.base_model = mobilenet.features  # backbone without the classifier
        last_channel = mobilenet.last_channel  # number of channels before the classifier

        # the input for the classifier should be two-dimensional, but we will have
        # [batch_size, channels, width, height]
        # so, let's do the spatial averaging: reduce width and height to 1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # create separate classifiers for our outputs
        self.action = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_action_classes)
        )
        self.priority = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_priority_classes)
        )
        self.diagnosis = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=n_action_classes + n_priority_classes, out_features=n_diagnosis_classes)
        )

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, 1)
        
        # Subclass predictions
        action = self.action(x)
        priority = self.priority(x)
        
        # Concatenate subclass outputs for parent class prediction
        combined_action_priority_outputs = torch.cat([action, priority], dim=1)
        diagnosis = self.diagnosis(combined_action_priority_outputs)
        
        return {
            'action': action,
            'priority': priority,
            'diagnosis': diagnosis
        }

    def get_loss(self, net_output, ground_truth):
        
        action_loss = F.cross_entropy(net_output['action'], ground_truth['action_labels'])
        priority_loss = F.cross_entropy(net_output['priority'], ground_truth['priority_labels'])
        diagnosis_loss = F.cross_entropy(net_output['diagnosis'], ground_truth['diagnosis_labels'])
        
        loss = action_loss + priority_loss + diagnosis_loss
        
        return loss, {'action': action_loss, 'priority': priority_loss, 'diagnosis': diagnosis_loss}
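
For reference, here is a minimal smoke test of the model above, assuming 3 action classes, 3 priority classes and 2 diagnosis classes (the counts matching the shapes posted further down); the dummy tensors are made up:

model = MultiOutputModel(n_action_classes=3, n_priority_classes=3, n_diagnosis_classes=2)
model.eval()

dummy_img = torch.randn(16, 3, 224, 224)  # batch of 16 RGB images, 224x224
with torch.no_grad():
    out = model(dummy_img)

print(out['action'].shape)     # torch.Size([16, 3])
print(out['priority'].shape)   # torch.Size([16, 3])
print(out['diagnosis'].shape)  # torch.Size([16, 2])

# random integer targets of shape [batch_size], just to exercise get_loss
dummy_labels = {
    'action_labels': torch.randint(0, 3, (16,)),
    'priority_labels': torch.randint(0, 3, (16,)),
    'diagnosis_labels': torch.randint(0, 2, (16,)),
}
loss, per_head_losses = model.get_loss(out, dummy_labels)
print(loss.item(), {k: v.item() for k, v in per_head_losses.items()})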

The results give me an accuracy of 1. Am I addressing the problem the right way?

Could you share how you’ve computed the accuracy and also post the model output shapes as well as the shapes of all targets?

This is the training accuracy calculation:

for epoch in range(start_epoch, N_epochs + 1):

    total_loss = 0
    accuracy_action = 0
    accuracy_priority = 0
    accuracy_diagnosis = 0

    print(f"Epoch {epoch} started...")

    for batch in train_dataloader:
        optimizer.zero_grad()

        img = batch['img'].to(device)
        target_labels = {t: batch['labels'][t].to(device) for t in batch['labels']}

        # forward pass
        output = model(img)
        action_output = output['action']
        priority_output = output['priority']
        diagnosis_output = output['diagnosis']

        # calculate the loss
        loss_train, losses_train = model.get_loss(output, target_labels)
        total_loss += loss_train.item()

        batch_accuracy_action, batch_accuracy_priority, batch_accuracy_diagnosis = \
            calculate_metrics(output, target_labels)

        accuracy_action += batch_accuracy_action
        accuracy_priority += batch_accuracy_priority
        accuracy_diagnosis += batch_accuracy_diagnosis
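
The snippet stops after accumulating the batch metrics. A minimal sketch of how the epoch would typically be finished, mirroring the per-batch averaging used in the validation code below (this continuation is an assumption, not part of the posted code):

        # still inside the batch loop: backpropagate and update the weights
        loss_train.backward()
        optimizer.step()

    # after the batch loop: average the accumulated per-batch metrics
    n_batches = len(train_dataloader)
    print("Train loss: {:.4f}, action: {:.4f}, priority: {:.4f}, diagnosis: {:.4f}".format(
        total_loss / n_batches,
        accuracy_action / n_batches,
        accuracy_priority / n_batches,
        accuracy_diagnosis / n_batches))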

And this is the validation accuracy calculation:

     
def validate(model, dataloader, logger, iteration, device, checkpoint=None):
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)

    model.eval()
    with torch.no_grad():
        avg_loss = 0
        accuracy_action = 0
        accuracy_priority = 0
        accuracy_diagnosis = 0

        for batch in dataloader:
            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            val_loss, val_losses = model.get_loss(output, target_labels)
            avg_loss += val_loss.item()
            batch_accuracy_action, batch_accuracy_priority, batch_accuracy_diagnosis = \
                calculate_metrics(output, target_labels)

            accuracy_action += batch_accuracy_action
            accuracy_priority += batch_accuracy_priority
            accuracy_diagnosis += batch_accuracy_diagnosis

    n_batches = len(dataloader)  # number of batches; the metrics were accumulated per batch
    avg_loss /= n_batches
    accuracy_action /= n_batches
    accuracy_priority /= n_batches
    accuracy_diagnosis /= n_batches
    print('-' * 72)
    print("Validation  loss: {:.4f}, action: {:.4f}, priority: {:.4f}, diagnosis: {:.4f}\n".format(
        avg_loss, accuracy_action, accuracy_priority, accuracy_diagnosis))

    logger.add_scalar('val_loss', avg_loss, iteration)
    logger.add_scalar('val_accuracy_action', accuracy_action, iteration)
    logger.add_scalar('val_accuracy_priority', accuracy_priority, iteration)
    logger.add_scalar('val_accuracy_diagnosis', accuracy_diagnosis, iteration)


    model.train()

The model output shapes:

Action output shape: torch.Size([16, 3])
Priority output shape: torch.Size([16, 3])
Diagnosis output shape: torch.Size([16, 2])

Input shape: torch.Size([16, 3, 224, 224])

Target shapes:

action_targets shape : torch.Size([16])
priority_targets shape : torch.Size([16])
diagnosis_targets shape : torch.Size([16])

Could it be that the accuracy of the diagnosis is computed independently from the other two?

Thank you! Could you also post the calculate_metrics method?

def calculate_metrics(output, target):
    _, predicted_action = output['action'].cpu().max(1)
    gt_action = target['action_labels'].cpu()

    _, predicted_priority = output['priority'].cpu().max(1)
    gt_priority = target['priority_labels'].cpu()

    _, predicted_diagnosis = output['diagnosis'].cpu().max(1)
    gt_diagnosis = target['diagnosis_labels'].cpu()

    # sklearn may warn when a class is missing from a batch (zero row in the confusion matrix)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        accuracy_action = balanced_accuracy_score(y_true=gt_action.numpy(), y_pred=predicted_action.numpy())
        accuracy_priority = balanced_accuracy_score(y_true=gt_priority.numpy(), y_pred=predicted_priority.numpy())
        accuracy_diagnosis = balanced_accuracy_score(y_true=gt_diagnosis.numpy(), y_pred=predicted_diagnosis.numpy())

    return accuracy_action, accuracy_priority, accuracy_diagnosis

Thanks! Unfortunately, your code is still not executable since balanced_accuracy_score is undefined. To debug the issue further, check the shapes of the model outputs and targets that are actually used, and check whether unintended broadcasting happens internally somewhere, which could yield wrong outputs. You could also check whether the docs for balanced_accuracy_score explain the expected input shapes and dtypes, and compare them to your code.
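
A minimal sketch of that kind of check, assuming the method is sklearn.metrics.balanced_accuracy_score and using made-up 1D integer labels that match the posted target shapes:

import torch
from sklearn.metrics import balanced_accuracy_score

# 1D integer class labels of shape [batch_size], as in the posted targets
y_true = torch.tensor([0, 0, 1, 1, 2, 2])
y_pred = torch.tensor([0, 1, 1, 1, 2, 0])
print(y_true.shape, y_pred.shape)  # torch.Size([6]) torch.Size([6])

# per-class recalls are 1/2, 2/2 and 1/2, so the balanced accuracy is 2/3
print(balanced_accuracy_score(y_true.numpy(), y_pred.numpy()))  # ~0.6667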

@ptrblck thank you for your reply. So the way I defined the model and calculate the losses is correct if I want to predict a class based on the predictions of its subclasses?

I don’t know, as the balanced_accuracy_score method is undefined, which is why I recommended further debugging steps.