RuntimeError "but got 2 channels instead"

Hello,

I modified Pix2Pix in order to use a style loss instead of the L1 loss (for theoretical reasons). When I test my style loss between two images of my training set on its own, it works fine, but inside my network I get the error RuntimeError: Given groups=1, weight of size [64, 3, 3, 3, 3], expected input [1, 2, 256, 256] to have 3 channels, but got 2 channels instead. I understand what this error means, but I don't see where it could come from.
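
For context, here is a tiny made-up example (the layer and the shapes are not from my actual code, just an illustration) that raises the same kind of error, so you can see what I mean by the channel mismatch:

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)  # weights expect 3 input channels
x = torch.randn(1, 2, 256, 256)                                  # tensor with only 2 channels
conv(x)  # -> RuntimeError: ... expected input[1, 2, 256, 256] to have 3 channels, but got 2 channels instead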

import torch
import torchvision.models as models
from .base_model import BaseModel
from . import networks
from util.util import gram_matrix, get_features, style_loss



class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires '--dataset_mode aligned' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For pix2pix, we do not use image buffer
        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
        By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
        """
        # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_StyleLoss', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
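        # pretrained VGG19 feature extractor used to compute the style loss (its weights are never updated and it is not saved)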
        self.vgg = models.vgg19(pretrained=True).features
        self.vgg.to(self.device)


        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionStyleLoss = style_loss
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, the style of G(A) should match the style of B
        style_features = get_features(self.real_B, self.vgg)
        style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
        target_features = get_features(self.fake_B, self.vgg)  # features of the generated image, not of real_B
        self.loss_G_StyleLoss = self.criterionStyleLoss(style_grams, target_features) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_StyleLoss
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()                   # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        self.backward_D()                # calculate gradients for D
        self.optimizer_D.step()          # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        self.backward_G()                   # calculate gradients for G
        self.optimizer_G.step()             # update G's weights

And here are the functions from util that I use to retrieve the features:

def gram_matrix(tensor):

  # get batch_size, depth, height, width of the tensor (the batch dimension is discarded, so this assumes batch size 1)
  _, d, h, w = tensor.size()
  # reshape so we are multiplying height and width
  tensor = tensor.view(d, h * w)
  # calc. gram matrix
  gram = torch.mm(tensor, tensor.t())

  return gram

def get_features(image, model, layers=None):
  """Pass the image through the model and return the features of the specified layers."""

  if layers is None:
    layers = {'0' : 'conv1_1',
              '5' : 'conv2_1',
              '10' : 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',
              '28': 'conv5_1'}

  features = {}
  x = image
  for name, layer in model._modules.items():
    x = layer(x)   #passing image through layer
    if name in layers:
      features[layers[name]] = x

  return features

style_weights = {'conv1_1': 1,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

def style_loss(style_grams, target_features):
    loss = 0

    for layer in style_weights:
      target_feature = target_features[layer]
      target_gram = gram_matrix(target_feature)
      _, d, h, w = target_feature.shape
      style_gram = style_grams[layer]
      layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
      loss += layer_style_loss / (d * h * w)

    return loss
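
For what it's worth, this is roughly how I test the style loss on two images of my training set outside of the network (the image paths and the resize are just placeholders), and it runs fine:

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from util.util import gram_matrix, get_features, style_loss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = models.vgg19(pretrained=True).features.to(device).eval()

preprocess = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
img_style = preprocess(Image.open('some_image_A.jpg').convert('RGB')).unsqueeze(0).to(device)   # [1, 3, 256, 256]
img_target = preprocess(Image.open('some_image_B.jpg').convert('RGB')).unsqueeze(0).to(device)  # [1, 3, 256, 256]

style_features = get_features(img_style, vgg)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
target_features = get_features(img_target, vgg)
print(style_loss(style_grams, target_features))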

Thank you in advance for your help

Based on the weight shape I guess the error is raised by an nn.Conv3d layer. Could you check where such a layer is used in this config and make sure the input to it has 3 valid channels?
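
For reference, a 5-dimensional weight like [64, 3, 3, 3, 3] is what an nn.Conv3d(3, 64, kernel_size=3) has, while the 2D convolutions in VGG19 and in the standard pix2pix generators and discriminators all have 4-dimensional weights. A quick check:

import torch.nn as nn

conv3d = nn.Conv3d(in_channels=3, out_channels=64, kernel_size=3)
print(conv3d.weight.shape)  # torch.Size([64, 3, 3, 3, 3]) -- matches the weight in your error message

conv2d = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
print(conv2d.weight.shape)  # torch.Size([64, 3, 3, 3])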

Yes, thank you, that was it!