I’m currently working on a task where I need to combine a large number of images into a single image. I do this in two steps: first I pass X images into N segmentation models, and then I pass the N results into another segmentation model to create the final image. My current issues are how to create the N intermediate images and how to deal with missing data.
The code fragment below shows what I currently have working. In this example, `x` is a list of tensors, where each tensor represents X images, and each tensor contains a different set of X images. Due to the nature of the dataset, it is not always possible to create N tensors. This code works, but only when I set the batch size to 1 and replace missing images with tensors filled with zeros.
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    segmented_images_list = []
    for current_image in x:
        # Segment each image individually with the first model.
        intermediate_result = self.model_1(current_image)
        segmented_images_list.append(intermediate_result)
    # Combine the N intermediate results into a single tensor for the second model.
    intermediate_tensor = torch.cat(segmented_images_list, 0).permute(1, 0, 2, 3)
    y_hat = self.model_2(intermediate_tensor)
    loss = self.loss_function(y_hat, y)
    self.log("train/loss", loss)
    return loss
```
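For context, the zero-filling of missing images mentioned above currently looks roughly like this (a minimal sketch; `N`, `IMAGE_SHAPE`, and `pad_missing` are hypothetical names, and the actual shapes depend on my dataset):

```python
import torch

N = 4                        # number of inputs the second segmenter expects (example value)
IMAGE_SHAPE = (1, 256, 256)  # hypothetical per-image tensor shape

def pad_missing(images, n=N, shape=IMAGE_SHAPE):
    """Pad a list of image tensors with all-zero tensors until it has n entries."""
    return images + [torch.zeros(shape) for _ in range(n - len(images))]

# Example: only 2 of the 4 expected images are available for this sample.
x = [torch.rand(IMAGE_SHAPE) for _ in range(2)]
x = pad_missing(x)
print(len(x))  # 4
```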
I have two questions:
- Is there a more elegant solution than iterating over the tensors, ideally one that would let me use a batch size higher than 1?
- How should I handle missing data? The second segmenter has a fixed number of input channels, meaning I must always supply N tensors.
Thanks for reading!