How do I calculate the loss when samples per batch have different shapes?

I have a training function like so:

import torch
import torch.nn.functional as F
from tqdm import tqdm


def training():
    # model, optimizer, device and train_dataloader are defined elsewhere
    train_mae = []
    progress = tqdm(train_dataloader, desc='Training')
    for batch_index, batch in enumerate(progress):
        x = batch['x'].to(device)
        x_lengths = batch['x_lengths'].to(device)
        y = batch['y'].to(device)
        y_type = batch['y_type'].to(device)
        y_valid_indices = batch['y_valid_indices'].to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        y_first, y_second = model(x)

        losses = []

        for j in range(len(x_lengths)):
            x_length = x_lengths[j].item()

            # Pick the output head that matches this sample's type
            if y_type[j].item() == 0:
                predicted = y_first[j]
            else:
                predicted = y_second[j]

            actual = y[j]
            valid_mask = torch.zeros_like(predicted, dtype=torch.bool)
            valid_mask[:x_length] = True
            # Padding of -1 is removed from y
            indices_mask = y[j].ne(-1)
            valid_indices = y[j][indices_mask]

            valid_predicted = predicted[valid_mask]
            valid_actual = actual[valid_mask]
            loss = mae_fn(valid_predicted, valid_actual, valid_indices)
            losses.append(loss)

        # Backward pass and update
        loss = torch.stack(losses).mean()   # This fails due to different shapes
        loss.backward()
        optimizer.step()

        progress.set_postfix_str(
            f"mae: {loss.detach().cpu().numpy():.4f}"
        )

    # Return the average MAEs for y type
    return (...)  # the rest of this return statement was cut off in the post


def mae_fn(output, target, indices):
    # Per-element L1 loss against the target clipped to [0, 1],
    # keeping only the entries at the given indices
    clipped_target = torch.clip(target, min=0, max=1)
    maes = F.l1_loss(output, clipped_target, reduction='none')
    return maes[indices]

Obviously I can't stack these losses, since they have different shapes due to the indices. Taking the mean of maes[indices] solves the shape issue, but it results in a very bad test loss. The indices determine the shape, depending on y_type.
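A minimal sketch of the shape problem, using made-up inputs with `mae_fn` copied from above: two samples with a different number of valid indices produce loss vectors of different lengths, so `torch.stack` raises an error. Reducing each sample to a scalar first (the per-sample mean mentioned above) makes the shapes match again.

```python
import torch
import torch.nn.functional as F


def mae_fn(output, target, indices):
    # Same as mae_fn in the question: per-element L1 loss on the clipped
    # target, then keep only the requested positions
    clipped_target = torch.clip(target, min=0, max=1)
    maes = F.l1_loss(output, clipped_target, reduction='none')
    return maes[indices]


# Two hypothetical samples with a different number of valid indices
loss_a = mae_fn(torch.zeros(5), torch.ones(5), torch.tensor([0, 1, 2]))
loss_b = mae_fn(torch.zeros(5), torch.ones(5), torch.tensor([0, 1]))
print(loss_a.shape, loss_b.shape)  # torch.Size([3]) torch.Size([2])

stack_failed = False
try:
    torch.stack([loss_a, loss_b])
except RuntimeError as err:
    # stack requires every tensor to have the same shape
    stack_failed = True
    print("stack failed:", err)

# Reducing each sample to a scalar first makes the shapes match again
per_sample_mean = torch.stack([loss_a.mean(), loss_b.mean()]).mean()
```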

How should I calculate the loss here?

Hi, it would help if you could explain why the samples in a batch have different shapes. Either way, I have made some assumptions and suggested an approach.

It would be nice to know what you are training here, but looking at the code I am assuming this is a segmentation model. If my understanding is right, the only reason you iterate through the batch elements is to remove the padding. If that is the case, pad whichever of the prediction or the target is shorter so that both have the same shape, and calculate the per-element loss for the whole batch at once (for example F.l1_loss with reduction='none').
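A minimal sketch of the padding step, assuming the per-sample targets arrive as a Python list of 1-D tensors of different lengths (the tensors here are made up, and the padding value of -1 matches the -1 padding mentioned in the question). `torch.nn.utils.rnn.pad_sequence` right-pads them to a common length so the whole batch fits in one tensor, and the true lengths are kept for masking later:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical variable-length targets for a batch of three samples
targets = [
    torch.tensor([0.2, 0.4, 0.9]),
    torch.tensor([0.5]),
    torch.tensor([0.1, 0.3]),
]

# Keep the true lengths so the padding can be masked out later
lengths = torch.tensor([t.numel() for t in targets])

# Right-pad to the longest sample with -1, giving shape (batch, max_len)
padded = pad_sequence(targets, batch_first=True, padding_value=-1.0)

print(padded.shape)  # torch.Size([3, 3])
print(lengths)       # tensor([3, 1, 2])
```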
Then generate a boolean mask that is True at the valid (non-padded) positions. With this mask, torch.masked_select picks out the loss values you want to keep, and you can then take the mean or sum over them for the whole batch.
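The masking step could look like the following sketch. It assumes padded predictions and targets of shape `(batch, max_len)` plus a `lengths` tensor (all values made up), computes the per-element L1 loss with `reduction='none'`, builds the boolean mask by comparing each position against its sample's length, and averages only the selected elements. The backward pass shows that padded positions get zero gradient.

```python
import torch
import torch.nn.functional as F

# Hypothetical padded batch: 2 samples, max length 4 (padding value -1)
predicted = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                          [0.5, 0.6, 0.7, 0.8]], requires_grad=True)
target = torch.tensor([[0.0, 0.0, 0.0, -1.0],
                       [1.0, 1.0, -1.0, -1.0]])
lengths = torch.tensor([3, 2])

# Per-element loss with no reduction, shape (batch, max_len)
loss = F.l1_loss(predicted, target, reduction='none')

# mask[i, t] is True when position t lies inside sample i's real length
positions = torch.arange(target.size(1), device=lengths.device)
mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

# Keep only the valid elements (a 1-D tensor), then reduce to one scalar
valid_losses = torch.masked_select(loss, mask)
batch_loss = valid_losses.mean()
batch_loss.backward()

print(batch_loss.item())  # about 0.3
print(predicted.grad)     # zeros at the padded positions
```

Because the mean runs over all selected elements, every valid element carries the same weight regardless of which sample it belongs to, whereas averaging per-sample means gives every sample the same weight. If the -1 padding in the target also marks invalid positions, `mask & target.ne(-1)` applies both conditions.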

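import torch

# Per-element loss, e.g. from F.l1_loss(..., reduction='none')
loss = torch.randn((7, 24, 24))
# Placeholder mask (arbitrary region): True where an element is valid
mask = torch.zeros_like(loss, dtype=torch.bool)
mask[:, :12, :12] = True
loss_to_backward = torch.masked_select(loss, mask).mean()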
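The per-sample `if y_type[j].item() == 0` branch from the question can also be vectorized so that the masking approach works on the whole batch. A minimal sketch, assuming `y_first` and `y_second` both have shape `(batch, max_len)` and `y_type` is a 1-D tensor of 0/1 flags (the values here are made up): `torch.where` takes the first head for type-0 samples and the second head otherwise.

```python
import torch

# Hypothetical outputs of the two heads for a batch of 3 samples
y_first = torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
y_second = torch.tensor([[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
y_type = torch.tensor([0, 1, 0])

# Broadcast the per-sample flag over the sequence dimension:
# type 0 -> take y_first, anything else -> take y_second
use_first = (y_type == 0).unsqueeze(1)
predicted = torch.where(use_first, y_first, y_second)

print(predicted)
# tensor([[1., 1.],
#         [2., 2.],
#         [1., 1.]])
```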