Per class loss and accuracy in U-Net

I’ve already seen “How to find individual class accuracy”, and I am trying to get the same output (per-class accuracy) as in that thread. My dataset has three classes: Background, Class 1 and Class 2, mapped to 0, 1 and 2 respectively. I currently use the following code to calculate the accuracy over the whole output:

def multi_acc(pred, label):
    """Return the overall accuracy of ``pred`` against ``label`` in percent.

    Args:
        pred: Tensor of class scores with the class axis at dim 1; the
            predicted class is the argmax over that axis.
        label: Tensor of integer class indices. It may either have the shape
            of ``pred`` without the class axis, or keep a singleton axis at
            dim 1 (e.g. a mask stored with one channel).

    Returns:
        A 0-dim tensor holding the percentage of elements whose predicted
        class equals the label.
    """
    tags = torch.argmax(pred, dim=1)
    # A label that still carries a singleton class/channel axis would
    # broadcast against ``tags`` (batch against batch) and silently yield a
    # wrong score for batch sizes > 1, so drop that axis first.
    if label.dim() == pred.dim() and label.shape[1] == 1:
        label = label.squeeze(1)
    corrects = (tags == label).float()
    acc = corrects.sum() / corrects.numel()
    return acc * 100

How can I modify the above to get the following output:

Class 1 loss: x, Class 1 Accuracy: y
Class 2 loss: x, Class 2 Accuracy: y
Class 3 loss: x, Class 3 Accuracy: y

Full code for reference:

def train_net(
    net,
    n_channels,
    n_classes,
    class_weights,
    epochs=1,
    val_precent=0.1,
    batch_size=1,
    lr=0.0001,
    weight_decay=1e-8,
    momentum=0.99,
):
    """Train ``net`` with weighted cross-entropy and save its weights.

    The dataset built from the module-level ``data_folder`` is split randomly
    into a training and a validation part. After every epoch the mean loss and
    mean accuracy per batch are printed for both parts and, when the
    module-level ``wandb_track`` flag is set, logged to Weights & Biases.

    Args:
        net: Model to train; called as ``net(imgs)``.
        n_channels: Expected size of dim 1 of every image batch.
        n_classes: Not used by this function; kept for caller compatibility.
        class_weights: Per-class weights passed to ``nn.CrossEntropyLoss``.
        epochs: Number of passes over the training split.
        val_precent: Fraction of the dataset reserved for validation.
        batch_size: Batch size of both data loaders.
        lr: Learning rate of the RMSprop optimizer.
        weight_decay: Weight decay of the RMSprop optimizer.
        momentum: Not used; the RMSprop optimizer is built without momentum.
            Kept for caller compatibility.
    """
    print("Creating dataset for training...")
    dataset = Loader(data_folder)
    n_val = int(len(dataset) * val_precent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(
        train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
    )
    val_loader = DataLoader(
        val, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
    )
    # Loss and accuracy are accumulated once per batch, so the epoch means
    # must be taken over the number of batches, not the number of samples.
    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss(
        weight=torch.Tensor(class_weights).to(device=device)
    )
    if wandb_track:
        wandb.watch(net)
    for epoch in range(epochs):
        net.train()
        tepoch_loss = 0
        tepoch_acc = 0
        vepoch_loss = 0
        vepoch_acc = 0

        for batch in train_loader:
            imgs = batch["image"]
            masks = batch["mask"]
            assert imgs.shape[1] == n_channels, (
                f"Network has been defined with {n_channels} input channels, "
                f"but loaded images have {imgs.shape[1]} channels. Please check that "
                "the images are loaded correctly."
            )
            imgs = imgs.to(device=device, dtype=torch.float32)
            masks = masks.to(device=device, dtype=torch.long)
            # Presumably masks carry a singleton channel axis at dim 1; it is
            # dropped so the targets line up with the argmax of the output.
            targets = masks.squeeze(1)

            masks_pred = net(imgs)
            loss = criterion(masks_pred, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tepoch_loss += loss.item()
            # Use the squeezed targets: comparing against the unsqueezed mask
            # broadcasts batch against batch for batch sizes > 1.
            tepoch_acc += multi_acc(masks_pred, targets).item()

        net.eval()
        with torch.no_grad():
            for batch in val_loader:
                imgs = batch["image"]
                masks = batch["mask"]
                imgs = imgs.to(device=device, dtype=torch.float32)
                masks = masks.to(device=device, dtype=torch.long)
                targets = masks.squeeze(1)

                masks_pred = net(imgs)

                loss = criterion(masks_pred, targets)

                vepoch_loss += loss.item()
                vepoch_acc += multi_acc(masks_pred, targets).item()

        # An empty split has no meaningful mean; report NaN instead of
        # raising ZeroDivisionError.
        nan = float("nan")
        tepoch_loss = tepoch_loss / n_train_batches if n_train_batches else nan
        tepoch_acc = tepoch_acc / n_train_batches if n_train_batches else nan
        vepoch_loss = vepoch_loss / n_val_batches if n_val_batches else nan
        vepoch_acc = vepoch_acc / n_val_batches if n_val_batches else nan

        print(
            "Epoch {0:} finished, Training loss: {1:.4f} [{2:.2f}%]  Validation loss: {3:.4f} [{4:.2f}%]".format(
                epoch + 1, tepoch_loss, tepoch_acc, vepoch_loss, vepoch_acc
            )
        )
        if wandb_track:
            # NOTE(review): the "Test ..." keys hold the *training* metrics;
            # the key names are kept so existing dashboards keep working.
            wandb.log({"Test Accuracy": tepoch_acc, "Test Loss": tepoch_loss})
            wandb.log(
                {"Validation Accuracy": vepoch_acc, "Validation Loss": vepoch_loss}
            )

    # Create the output directory if needed; an existing one is fine.
    if model_path:
        os.makedirs(model_path, exist_ok=True)

    torch.save(net.state_dict(), os.path.join(model_path, model_name))
    if wandb_track:
        torch.save(net.state_dict(), os.path.join(wandb.run.dir, model_name))


def multi_acc(pred, label):
    """Compute the percentage of elements classified correctly.

    Args:
        pred: Tensor of class scores; the class axis is dim 1 and the
            prediction is its argmax.
        label: Tensor of integer class indices, shaped like ``pred`` without
            the class axis. A singleton axis at dim 1 is also accepted and
            is removed before comparing.

    Returns:
        A 0-dim tensor with the accuracy in percent.
    """
    tags = torch.argmax(pred, dim=1)
    # Without this, a (N, 1, ...) label broadcasts against the (N, ...)
    # predictions and compares every sample with every other sample.
    if label.dim() == pred.dim() and label.shape[1] == 1:
        label = label.squeeze(1)
    corrects = (tags == label).float()
    acc = corrects.sum() / corrects.numel()
    return acc * 100

This isn’t very efficient (just like the code you linked), but it works:

def multi_acc(pred, label, num_classes=None):
    """Return the per-class accuracy in percent, one entry per class.

    Args:
        pred: Tensor of class scores with the class axis at dim 1; the
            predicted class is the argmax over that axis.
        label: Tensor of integer class indices, shaped like ``pred`` without
            the class axis. A singleton axis at dim 1 is also accepted.
        num_classes: Number of classes to report. Defaults to the size of the
            class axis of ``pred``.

    Returns:
        A list of 0-dim tensors; entry ``c`` is the percentage of elements
        labelled ``c`` that were predicted as ``c``. The entry is NaN when no
        element is labelled ``c``.
    """
    if num_classes is None:
        num_classes = pred.shape[1]
    tags = torch.argmax(pred, dim=1)
    # Drop a singleton channel axis so label and prediction compare
    # element-wise instead of broadcasting batch against batch.
    if label.dim() == pred.dim() and label.shape[1] == 1:
        label = label.squeeze(1)
    hits = tags == label  # independent of the class, so computed once
    accs_per_label_pct = []
    for c in range(num_classes):
        of_c = label == c
        num_total_per_label = of_c.sum()
        num_corrects_per_label = (of_c & hits).sum()
        accs_per_label_pct.append(num_corrects_per_label / num_total_per_label * 100)
    return accs_per_label_pct

Note that you don’t need to cast to float here with non-ancient PyTorch.

For the loss, you can use reduction='none' to get the elementwise loss. Then you can again filter by the label: for example, loss[label == c].mean() is the mean loss over the elements whose target is class c.

Best regards

Thomas

Thanks Tom, any tips on making it more efficient? What should I improve or change?

I’d not worry about it before it becomes your bottleneck.

Hey Tom, a question regarding the per-class loss: when I change the loss to use reduction='none', I get a tensor with the same shape as the target (1, 512, 512). I didn’t understand how the filtering is done to obtain the loss per class; as far as I know, I am supposed to take the mean of each dimension for each class (?). Also, once I have the loss per class, I still need to compute the mean of the loss using the custom weights that I have prepared. Any idea how to do this efficiently? My idea was to use two loss variables: one with the reduction set to mean, to calculate the loss as usual, and another with the reduction set to none, for the per-class loss calculation…