Pix2Pix visualizer error for 6 input and 3 output channels

I’m trying to train pix2pix to go from a combination of 2 RGB images (6 input channels) to 1 RGB image (3 output channels). Each sample in my dataset is a single image showing the same portion of the sky three times, side by side: in optical, ultraviolet and infrared (false-coloured), respectively.


I set --input_nc 6 and modified __getitem__ in aligned_dataset.py so that it can feed 2 images (6 channels) to the network, like this:

        w, h = AB.size
        w3 = int(w / 3)
        A = AB.crop((0, 0, w3, h))
        B = AB.crop((w3, 0, w3*2, h))
        C = AB.crop((w3*2, 0, w, h))

        # apply the same transform to A, B and C
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))

        A = A_transform(A)
        B = B_transform(B)
        C = B_transform(C)
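        # stack B and C along the channel axis: 3 + 3 = 6 channels
        # (needs "import torch" at the top of aligned_dataset.py)
        B = torch.cat((B, C))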
        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
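To make the resulting shapes explicit, here is a minimal sketch of the same crop-and-stack step. It uses plain NumPy arrays in place of PIL images and torch tensors, and the function name split_triptych is made up for illustration. It assumes the width is divisible by 3.

```python
import numpy as np

def split_triptych(ab):
    """Split an (H, 3W, 3) side-by-side image into A (3, H, W) and stacked B+C (6, H, W).

    Hypothetical helper: NumPy stand-in for the PIL crops and torch.cat above.
    """
    h, w, _ = ab.shape
    w3 = w // 3
    # the same three crops as in __getitem__, written as array slices
    a, b, c = ab[:, :w3], ab[:, w3:2 * w3], ab[:, 2 * w3:]
    # H x W x C -> C x H x W, the layout the transforms produce
    to_chw = lambda img: img.transpose(2, 0, 1)
    return to_chw(a), np.concatenate((to_chw(b), to_chw(c)), axis=0)

ab = np.zeros((256, 768, 3), dtype=np.float32)
A, B = split_triptych(ab)
print(A.shape, B.shape)  # (3, 256, 256) (6, 256, 256)
```

So one entry of the returned dict has 3 channels and the other has 6, which matters for anything downstream that expects 1 or 3 channels.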

I get the following error during epoch 1, the first time the visualizer tries to display results:

(epoch: 1, iters: 100, time: 1.336, data: 0.211) G_GAN: 1.438 G_L1: 2.366 D_real: 0.543 D_fake: 0.663 
(epoch: 1, iters: 200, time: 1.338, data: 0.005) G_GAN: 0.956 G_L1: 1.331 D_real: 0.871 D_fake: 0.448 
(epoch: 1, iters: 300, time: 1.338, data: 0.003) G_GAN: 0.834 G_L1: 2.449 D_real: 0.504 D_fake: 0.583 
/usr/local/lib/python3.7/dist-packages/visdom/__init__.py:366: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  return np.array(a)
Traceback (most recent call last):
  File "train.py", line 57, in <module>
    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
  File "/content/pytorch-CycleGAN-and-pix2pix/util/visualizer.py", line 154, in display_current_results
    padding=2, opts=dict(title=title + ' images'))
  File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 389, in wrapped_f
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/visdom/__init__.py", line 1292, in images
    height = int(tensor.shape[2] + 2 * padding)
IndexError: tuple index out of range

Any idea if other modifications are needed beyond the ones in the custom dataset and what they would be? Thanks a lot for any assistance.

visdom raises the error because tensor.shape[2] is out of range, i.e. the array it received has fewer than three dimensions.
You are visualizing the outputs of model.get_current_visuals(), so I would check exactly what it returns and what shape each entry has.
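A minimal sketch of such a check, assuming get_current_visuals() returns a dict that maps names to N x C x H x W tensors. NumPy arrays stand in for the tensors here, the entry names are illustrative, and both helper functions are made up for this example.

```python
import numpy as np

def describe_visuals(visuals):
    """Return {name: shape tuple} for a dict of array-like visuals."""
    return {name: tuple(v.shape) for name, v in visuals.items()}

def odd_channel_counts(visuals, allowed=(1, 3)):
    """Names of visuals whose channel count (index 1 of N x C x H x W) is not in `allowed`."""
    return [name for name, v in visuals.items() if v.shape[1] not in allowed]

# stand-in for model.get_current_visuals(); names and sizes are assumptions
visuals = {
    "real_A": np.zeros((1, 6, 256, 256)),
    "fake_B": np.zeros((1, 3, 256, 256)),
    "real_B": np.zeros((1, 3, 256, 256)),
}
print(describe_visuals(visuals))
print(odd_channel_counts(visuals))  # ['real_A']
```

Any entry flagged this way cannot be shown as a grayscale or RGB image without first being reduced to 1 or 3 channels.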

Thank you for your answer. With some additional insight from the creator of the code, I tracked the problem down to a function called tensor2im in util/util.py and modified it. For anyone with the same issue: I added the following to tensor2im, at the same indentation level as the if statement that carries the comment # grayscale to RGB:

        if image_numpy.shape[0] == 6:  # two RGB images stacked along the channel axis
            image_numpy, _ = np.vsplit(image_numpy, 2)  # keep the first image, discard the second

This splits the two input images that were concatenated to feed the network back apart along the channel axis and keeps only the first one, so that real A is displayed as an ordinary 3-channel image for reference.
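For reference, a small standalone sketch of what the split does to a 6 x H x W array. The wrapper function first_rgb_of_stack is made up for illustration and is not part of the repository.

```python
import numpy as np

def first_rgb_of_stack(image_numpy):
    """Return the first 3-channel image if two RGB images are stacked on axis 0."""
    if image_numpy.shape[0] == 6:
        # np.vsplit cuts along the first axis: (6, H, W) -> two arrays of (3, H, W)
        image_numpy, _ = np.vsplit(image_numpy, 2)
    return image_numpy

# first image all zeros, second image all ones, stacked to 6 channels
stacked = np.concatenate((np.zeros((3, 4, 4)), np.ones((3, 4, 4))), axis=0)
first = first_rgb_of_stack(stacked)
print(first.shape)  # (3, 4, 4)
print(first.max())  # 0.0
```

Arrays that already have 1 or 3 channels pass through unchanged, so the other visuals are not affected.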