How to train models that have a many-to-one relation

I'm currently working on a task in which I need to combine a large number of images into a single image. I do this in two steps: first, each of N groups of X images is passed through a segmentation model, producing N intermediate images; then those N intermediate images are passed into a second segmentation model that creates the final image. My current problems are how best to create the N intermediate images and how to deal with missing data.
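For reference, here is a minimal sketch of that two-stage structure as a single PyTorch module. It assumes, as in the fragment further down, that the same model_1 is shared across all N groups, that the groups are already stacked into one (B, N, X, H, W) tensor, and that model_1 returns exactly one channel per group. TwoStageSegmenter and the Conv2d stand-ins are hypothetical. Folding the group axis into the batch axis lets model_1 process all B * N groups in one call instead of looping in Python.

```python
import torch
from torch import nn


class TwoStageSegmenter(nn.Module):
    """Hypothetical wrapper: N groups of X images -> N maps -> 1 fused map."""

    def __init__(self, model_1: nn.Module, model_2: nn.Module):
        super().__init__()
        self.model_1 = model_1  # assumed: (B, X, H, W) -> (B, 1, H, W)
        self.model_2 = model_2  # assumed: (B, N, H, W) -> (B, C_out, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, X, H, W), every group already stacked into one tensor.
        b, n, c, h, w = x.shape
        # Fold the group axis into the batch axis so model_1 runs once
        # over all B * N groups instead of once per group.
        maps = self.model_1(x.reshape(b * n, c, h, w))  # (B*N, 1, H, W)
        # Unfold: the N single-channel maps become N input channels.
        maps = maps.reshape(b, n, h, w)                  # (B, N, H, W)
        return self.model_2(maps)


# Stand-in networks for a shape check (N=4 groups, X=3 images per group).
N, X = 4, 3
net = TwoStageSegmenter(
    nn.Conv2d(X, 1, kernel_size=3, padding=1),
    nn.Conv2d(N, 1, kernel_size=3, padding=1),
)
out = net(torch.randn(2, N, X, 32, 32))
print(out.shape)  # torch.Size([2, 1, 32, 32])
```

One caveat of this design: layers that compute batch statistics, such as BatchNorm inside model_1, see B * N samples per step rather than B.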

The code fragment below shows what I currently have working. In this example, x is a list of tensors, each holding a different set of X images. Due to the nature of the dataset, some of these sets are sometimes missing, so it is not always possible to build all N tensors. This code works, but only when I set the batch size to 1 and replace the missing images with zero-filled tensors.

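    def training_step(self, batch, batch_idx):
        # x: list of N tensors, one per group of X images (missing groups
        # replaced with zeros); y: target for the final fused image
        x, y = batch

        segmented_images_list = []

        # run the first segmentation model once per group
        for current_image in x:
            intermediate_result = self.model_1(current_image)
            segmented_images_list.append(intermediate_result)

        # concatenate the N results and move them into the channel axis;
        # assuming model_1 returns (B, 1, H, W), cat along dim 0 followed by
        # this permute only yields (1, N, H, W) when the batch size is 1
        intermediate_tensor = torch.cat(segmented_images_list, 0).permute(1, 0, 2, 3)

        # fuse the N intermediate images into the final prediction
        y_hat = self.model_2(intermediate_tensor)
        loss = self.loss_function(y_hat, y)
        self.log("train/loss", loss)

        return loss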
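A minimal sketch of how the zero-filling could move into a custom collate_fn, so that the DataLoader can build batches larger than 1. It assumes, hypothetically, that each dataset item is (groups, y), where groups is a length-N list holding an (X, H, W) tensor for each present group and None for each missing one, and that at least one group per sample exists. The function name and this sample format are illustrative, not part of any library. It also returns a boolean mask recording which groups were real.

```python
import torch


def collate_with_missing(samples):
    """Hypothetical collate_fn: zero-fill missing groups and record a mask.

    Assumes each sample is (groups, y), where `groups` is a list of length N
    holding an (X, H, W) tensor per present group and None per missing group.
    """
    xs, presents, ys = [], [], []
    for groups, y in samples:
        # Use any present group as a template for shape, dtype and device.
        template = next(g for g in groups if g is not None)
        # Zero-fill missing groups while keeping their position, so group i
        # always ends up in channel i of the second model.
        filled = [g if g is not None else torch.zeros_like(template) for g in groups]
        xs.append(torch.stack(filled))                            # (N, X, H, W)
        presents.append(torch.tensor([g is not None for g in groups]))  # (N,) bool
        ys.append(y)
    # Shapes: (B, N, X, H, W), (B, N) bool, (B, ...)
    return torch.stack(xs), torch.stack(presents), torch.stack(ys)
```

It would be passed as DataLoader(dataset, batch_size=8, collate_fn=collate_with_missing), and the training step would then unpack x, present, y = batch.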

I have two questions:

  1. Is there a more elegant solution than iterating over the tensors, ideally one that would let me use a batch size larger than 1?
  2. How should I handle missing data? The second segmenter has a fixed number of input channels, meaning I must always supply N tensors.
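On the missing-data question, one common option (a sketch, not the only approach) is to keep zero-filling the inputs but also zero the corresponding intermediate maps, because model_1 generally does not map an all-zero input to an all-zero output (its biases alone can shift it). Randomly hiding some present groups during training, similar to dropout applied per group, can additionally expose model_2 to the gap patterns it will meet later. The function below is hypothetical and assumes (B, N, H, W) intermediate maps plus a (B, N) boolean presence mask built alongside the zero-filled inputs.

```python
import torch


def mask_missing_maps(maps, present, drop_prob=0.0, training=False):
    """Hypothetical helper: zero the intermediate maps of missing groups.

    maps:      (B, N, H, W) output of the first stage.
    present:   (B, N) bool, True where the group actually exists.
    drop_prob: assumed extra knob; while training, also hide present groups
               at random so the second model learns to cope with gaps.
    """
    keep = present.clone()
    if training and drop_prob > 0:
        # Randomly hide a fraction of the groups that are really present.
        keep &= torch.rand(present.shape, device=present.device) >= drop_prob
    # Broadcast the (B, N) mask over H and W; hidden maps become exactly 0,
    # whatever model_1 produced for its zero-filled input.
    return maps * keep[:, :, None, None].to(maps.dtype)
```

Inside training_step this would sit between the two models, e.g. maps = mask_missing_maps(maps, present, drop_prob=0.1, training=self.training). If model_2 can be rebuilt with 2N input channels, another option is to concatenate the presence mask as N extra constant channels so the network knows explicitly which inputs are absent.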

Thanks for reading :slight_smile: