RuntimeError: Given groups=1, weight of size [16, 2, 5, 5], expected input[16, 5, 80, 80] to have 2 channels, but got 5 channels instead

I am building a GAN whose input images have 4 channels and whose output images have 1 channel, with a U-Net as the generator. Here is my code:

MEAN = (0.5, 0.5, 0.5, 0.5,)
STD = (0.5, 0.5, 0.5, 0.5,)
RESIZE = 80  # 96, 128, 256
img_size = RESIZE

def argument_parser(img):
    ap = argparse.ArgumentParser()
    ap.add_argument("-i", "--image", type=str, default="opencv_logo.png", help="path to the input image")
    args = vars(ap.parse_args())
    return args


class Transform():
    def __init__(self, resize=RESIZE, mean=MEAN, std=STD):
        if resize > 128:
            self.data_transform = transforms.Compose([
                #transforms.Resize((resize*2, resize*2)),
                transforms.Resize((resize, resize)),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        else:
            self.data_transform = transforms.Compose([
                #transforms.Resize((resize*2, resize*2)),
                #transforms.Resize((resize, resize)),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])

    def __call__(self, img: Image.Image):
        return self.data_transform(img)
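Since MEAN and STD have four entries, Normalize assumes the tensor produced by ToTensor has 4 channels. A quick check sketch (the all-zero RGBA image below is a made-up placeholder; RESIZE = 80 is assumed):

import numpy as np
from PIL import Image

# A dummy 4-channel (RGBA) image: after ToTensor every value is 0.0,
# and Normalize maps it to (0.0 - 0.5) / 0.5 = -1.0.
dummy = Image.fromarray(np.zeros((100, 100, 4), dtype=np.uint8))
t = Transform()(dummy)
print(t.shape, t.min().item(), t.max().item())  # torch.Size([4, 80, 80]) -1.0 -1.0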

    
class Dataset(object):
    def __init__(self, files: List[str]):
        self.files = files
        self.transformer = Transform()
        
    def _separate(self, img) -> Tuple[Image.Image, Image.Image]:
        img = np.array(img, dtype=np.uint8)
        h, w, _ = img.shape
        w = int(w / 2)
        return Image.fromarray(img[:, w:, :]), Image.fromarray(img[:, :w, :])
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        img = Image.open(self.files[idx])
        input, output = self._separate(img)
        input_tensor = self.transformer(input)
        # Creating the output tensor with only the first channel from the input
        output_t = self.transformer(output)
        output_tensor = output_t[0:1, :, :]
        
        print('input tensor shape :', input_tensor.shape)
        print('output tensor shape :', output_tensor.shape)
        return input_tensor, output_tensor 
    
    def __len__(self):
        return len(self.files)
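A minimal usage sketch for this dataset; the glob pattern and batch size are placeholders, and with RESIZE = 80 and 4-channel source images these are the shapes the print statements in __getitem__ should report:

import torch
from glob import glob

files = sorted(glob("data/train/*.png"))  # placeholder path
train_dl = torch.utils.data.DataLoader(Dataset(files), batch_size=16, shuffle=True)

input_batch, target_batch = next(iter(train_dl))
print(input_batch.shape)   # expected: torch.Size([16, 4, 80, 80])
print(target_batch.shape)  # expected: torch.Size([16, 1, 80, 80])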
    
def show_img_sample(img: torch.Tensor, img1: torch.Tensor):
    fig, axes = plt.subplots(1, 2, figsize=(15, 8))
    ax = axes.ravel()
    ax[0].imshow(img.permute(1, 2, 0))
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].set_title("input image", c="g")
    ax[1].imshow(img1.permute(1, 2, 0))
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    ax[1].set_title("label image", c="g")
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

class DoubleConv(nn.Module):
    in_channels = 4 #5
    out_channels = 1 #5
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # kernel_size=3, stride=1, padding=1 gives a "same" convolution:
            # the spatial size of the output equals that of the input.
            # bias could be set to False since BatchNorm2d follows, but it is left True here.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True), #(SpectralNorm),
            nn.BatchNorm2d(out_channels),
            #nn.ReLU(inplace=True),
            nn.LeakyReLU(0.2, inplace=True),

            # second convolution
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True), #(SpectralNorm),
            nn.BatchNorm2d(out_channels),
            #nn.ReLU(inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)  # applies the two Conv-BN-LeakyReLU blocks defined above
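Because both convolutions are "same" convolutions, DoubleConv changes only the channel count, never the spatial size. A one-line shape check, assuming the class above:

import torch

block = DoubleConv(4, 64)
print(block(torch.randn(2, 4, 80, 80)).shape)  # torch.Size([2, 64, 80, 80])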
    
class UpConv(nn.Module):
    in_channels = 4 #5
    out_channels = 1
    def __init__(self, in_channels, out_channels):
        super(UpConv, self).__init__()
        self.conv = nn.Sequential(
            # kernel_size=3, stride=1, padding=1 gives a "same" convolution:
            # the spatial size of the output equals that of the input.
            # bias could be set to False since BatchNorm2d follows, but it is left True here.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True), #(SpectralNorm),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            # second convolution
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True), #(SpectralNorm),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        return self.conv(x)  # applies the two Conv-BN-ReLU-Dropout blocks defined above

class Generator(nn.Module):
#class UNET(nn.Module):
    in_channels = 4 #5
    out_channels = 1 #4 #5

    def __init__(self, in_channels=in_channels, out_channels=out_channels, features=[64, 128, 256, 512]):
        #super(UNET, self).__init__()
        super(Generator, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of the UNET
        
        self.downs1=DoubleConv(in_channels, 64)
        self.downs2=DoubleConv(64, 128)
        self.downs3=DoubleConv(128, 256)
        self.downs4=DoubleConv(256, 512)
        
        #self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.bottleneck = DoubleConv(512, 1024)
    
        
        self.ups4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0) #(SpectralNorm( ))
        self.ups_conv4 = UpConv(1024, 512)

        self.ups3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0) #(SpectralNorm)
        self.ups_conv3 = UpConv(512, 256)

        self.ups2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0) #(SpectralNorm)
        self.ups_conv2 = UpConv(256, 128)

        self.ups1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0) #(SpectralNorm)
        self.ups_conv1 = UpConv(128, 64)
        
        '''for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2,))
            self.ups.append(DoubleConv(feature*2, feature))'''

        

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1) 

    def forward(self, x):
        skip_connections = []
        x1 = self.downs1(x)
        skip_connections.append(x1)
        x1 = self.pool(x1)
        
        x2 = self.downs2(x1)
        skip_connections.append(x2)
        x2 = self.pool(x2)

        x3 = self.downs3(x2)
        skip_connections.append(x3)
        x3 = self.pool(x3)

        x4 = self.downs4(x3)
        skip_connections.append(x4)
        x4 = self.pool(x4)
        '''for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)'''

        x5 = self.bottleneck(x4)

        
        d6 = self.ups4(x5)
    
        skip_connection4 = skip_connections[3] 
        print("Size of skip_connection4:", skip_connection4.size())
        print("Size of d6 before concat:", d6.size())
        d6 = torch.cat((skip_connection4, d6), dim=1)
        print("Size of d6 after concat:", d6.size())
        d6 = self.ups_conv4(d6)
        
        d7 = self.ups3(d6)
        skip_connection3 = skip_connections[2] # x3
      
        print("Size of skip_connection3:", skip_connection3.size())
        print("Size of d7 before concat:", d7.size())
        d7 = torch.cat((skip_connection3, d7), dim=1)
        print("Size of d7 after concat:", d7.size())
        d7 = self.ups_conv3(d7)

        
        d8 = self.ups2(d7)

        skip_connection2 = skip_connections[1] # x4

        print("Size of skip_connection2:", skip_connection2.size())
        print("Size of d8 before concat:", d8.size())
        d8 = torch.cat((skip_connection2, d8), dim=1)
        print("Size of d8 after concat:", d8.size())
        d8 = self.ups_conv2(d8)

        
        d9 = self.ups1(d8)
        
        skip_connection1 = skip_connections[0] 
        print("Size of skip_connection1:", skip_connection1.size())
        print("Size of d9 before concat:", d9.size())
        d9 = torch.cat((skip_connection1, d9), dim=1)
        print("Size of d9 after concat:", d9.size())
        d9 = self.ups_conv1(d9)
       

        return torch.tanh(self.final_conv(d9))
    
      

# write a test of the UNET

def test():
    in_channels = 4 #5
    out_channels = 1 #5

    #x = torch.randn((10, 5, 160, 160))  # batch size = 10, channels = 5, image = 160 x 160 (divisible by 16)
    x = torch.randn((10, 4, 64, 64))
    #model = UNET(in_channels=in_channels, out_channels=out_channels)
    model = Generator(in_channels=in_channels, out_channels=out_channels)
    preds = model(x)
    print(preds.shape)
    print(x.shape)

    # assert preds.shape == x.shape

if __name__ == "__main__":
    test()
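With four stride-2 poolings, the input height and width should be divisible by 16 so that the upsampled feature maps line up with the skip connections; both 64 (used in test()) and 80 satisfy this. A sketch checking the 80 x 80 case from the training run, assuming the Generator above:

import torch

# 80 -> 40 -> 20 -> 10 -> 5 at the bottleneck, then back up to 80.
G = Generator(in_channels=4, out_channels=1)
print(G(torch.randn(2, 4, 80, 80)).shape)  # torch.Size([2, 1, 80, 80])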

class Discriminator(nn.Module):
    def __init__(self, in_channels=1, out_channels=4, hidden_channels=32):
        super(Discriminator, self).__init__()
        self.layer1 = self.conv2relu(in_channels*2, hidden_channels//2, 5, cnt=1)
        self.layer2 = self.conv2relu(hidden_channels//2, hidden_channels, pool_size=None)
        self.layer3 = self.conv2relu(hidden_channels, hidden_channels*2, pool_size=None)
        self.layer4 = self.conv2relu(hidden_channels*2, hidden_channels*4, pool_size=None)
        self.layer5 = self.conv2relu(hidden_channels*4, hidden_channels*8, pool_size=None)
        self.layer6 = self.conv2relu(hidden_channels*8, hidden_channels*16, pool_size=None)
        self.layer7 = self.conv2relu(hidden_channels*16, hidden_channels*32, pool_size=None)
        self.layer8 = nn.Conv2d(hidden_channels*32, 1, kernel_size=1) #(SpectralNorm)
        
    def conv2relu(self, in_c, out_c, kernel_size=3, pool_size=None, cnt=2):
        layer = []
        for i in range(cnt):
            if i == 0 and pool_size is not None:
                # downsample width and height
                layer.append(nn.AvgPool2d(pool_size)) #(SpectralNorm)
            # the first pass maps in_c -> out_c; later passes keep out_c
            layer.append(nn.Conv2d(in_c if i == 0 else out_c, # (SpectralNorm
                                   out_c,
                                   kernel_size,
                                   padding=(kernel_size-1)//2))
            layer.append(nn.BatchNorm2d(out_c))
            layer.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layer)
        
    def forward(self, x, x1):
        # concatenate the prediction and the conditioning image along the channel dimension
        x = torch.cat((x, x1), dim=1)
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)
        out6 = self.layer6(out5)
        out7 = self.layer7(out6)
        return torch.sigmoid(self.layer8(out7))
    

def test_D():
    in_channels = 4
    out_channels = 1

    #x = torch.randn((10, 5, 160, 160))  # batch size = 10, channels = 5, image = 160 x 160 (divisible by 16)
    x = torch.randn((1, 4, 30, 30))
    x1 = torch.randn((1, 4, 30, 30))
    #model = UNET(in_channels=in_channels, out_channels=out_channels)
    model = Discriminator(in_channels=in_channels, out_channels=out_channels)
    preds = model(x, x1)
    print(f'pred tensor dimension: {preds.shape}')
    print(f' image dimension: {x.shape} ')
    #assert preds.shape == x.shape

if __name__ == "__main__":
    test_D()
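Note the channel arithmetic that makes test_D pass: with in_channels=4, layer1 is built as conv2relu(4*2, ...), and the two concatenated 4-channel test tensors supply exactly 8 channels. A check sketch, assuming the Discriminator above:

import torch

D = Discriminator(in_channels=4, out_channels=1)
print(D.layer1[0])  # Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))

x = torch.randn(1, 4, 30, 30)
x1 = torch.randn(1, 4, 30, 30)
print(torch.cat((x, x1), dim=1).shape)  # torch.Size([1, 8, 30, 30]) -- matches layer1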

import re
import random
current_epoch = 0
# num_epoch = 2
# batch=2
# critics = 3

#def train_fn(train_dl, G, D, criterion_bce, criterion_smoothL1, optimizer_g, optimizer_d):
def train_fn(train_dl, G, D, GAN_loss, L1_loss, optimizer_g, optimizer_d):
#for epoch in range (current_epoch, num_epoch):
    G.train() #.to(device)
    D.train() #.to(device)
    LAMBDA_penalty = 10.0
    lambda_pixel = 10.0
    total_loss_g, total_loss_d = [],[]
    #for i, batch in enumerate(tqdm(train_dl)):
        #input_img = Variable(batch[0].type(Tensor))
        #real_img = Variable(batch[1].type(Tensor))
    for i, (input_img, real_img) in enumerate(tqdm(train_dl)):
        real_A = input_img.to(device)
        real_B = real_img.to(device)
        
        # ------------------
        #  Train Generators
        # ------------------
        optimizer_g.zero_grad()

        # GAN loss
        fake_B = G(real_A)
        pred_fake_g = D(fake_B, real_A).squeeze() # ROkwen added .squeeze()
        
        pred_real_label = D(real_B, real_A).squeeze() 
        
        # Adversarial ground truths
        # One-sided label smoothing: draw the real label from [0.8, 0.9)
        real_label = random.uniform(0.8, 0.9)
        real_label = torch.tensor(real_label)
        real_label = real_label.expand_as(pred_real_label).to(device)

        #real_label = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
        #print(f' real label value: {real_label.mean()}')
        
        fake_label = torch.tensor(0.0) 
        fake_label = fake_label.expand_as(pred_fake_g).to(device)
        #print(f' fake label value: {fake_label.mean()}')
        #####
        
        # Adversarial loss
        # Adversarial loss: GAN_loss is the criterion passed in
        # (nn.BCEWithLogitsLoss(), or nn.MSELoss() for an LSGAN objective)
        loss_g_adv = GAN_loss(pred_fake_g, real_label)
        
        #print("Adversarial loss:", loss_g_adv)
        
        # Dice loss
        #loss_g_dice = calc_loss(fake_B, real_B) 
        #loss_g_adv = loss_g_dice
        #print(f'Dice loss {loss_g_dice}')
        
        # Hinge loss
        #loss_g_adv = - pred_fake_g.mean()
        #print(f'Ganloss {loss_g_adv}')
        # Pixel-wise loss
        loss_pixel = L1_loss(fake_B, real_B)
        #loss_pixel = L1_loss(torch.tanh(fake_B), real_B)
        #torch.tanh(self.final_conv(d9))
        #print("pixel loss:", loss_pixel)
        #print(f'L1 loss {loss_pixel}')

        # Total loss
        loss_g = loss_g_adv + lambda_pixel * loss_pixel
        #loss_g = loss_g_dice
        
        #print(f'Total GAN loss {loss_g}')
        total_loss_g.append(loss_g.item())
        
        loss_g.backward()
        optimizer_g.step()
        
        #optimizer_d.zero_grad()
        #optimizer_g.zero_grad()
        
        # ------------------
        #  Train Discriminator
        # ------------------
        for _ in range (n_critic):
            
            optimizer_d.zero_grad()
            # Calculate Discriminator Losses 
        
            # Real loss
            pred_real = D(real_B, real_A).squeeze() # ROkwen added .squeeze()
        
            #real_img = input_img.to(device)
            #input_img = real_img.to(device)
        
            loss_d_real = GAN_loss(pred_real, real_label)
            #print("real D loss:", loss_d_real)
            #loss_d_real = torch.nn.ReLU()(1.0 - pred_real).mean()
            #print(f'real discriminator loss {loss_d_real}')
        
            # Fake loss
            #fake_B = G(real_A)
            pred_fake_d = D(fake_B.detach(), real_A).squeeze() # ROkwen added .squeeze()
            loss_d_fake = GAN_loss(pred_fake_d, fake_label)
            #print("fake D loss:", loss_d_fake)
            #print(f'fake discriminator loss {loss_d_fake}')
        
            loss_d = (loss_d_real + loss_d_fake) * 0.5  # adversarial D loss
        
            #loss_d = loss_d_real + loss_d_fake 
            total_loss_d.append(loss_d.item())

            #print(f'Discriminator loss: {loss_d}')

            # Backward + Optimize

            loss_d.backward()
            optimizer_d.step()

        # clear_output()
        
        real_A = torch.squeeze(real_A, 0)
        real_B = torch.squeeze(real_B, 0)
        fake_B = torch.squeeze(fake_B, 0)
        
        # 2, 4, 80, 80
        real_A = real_A[0,:,:,:]
        real_B = real_B[0,0,:,:]
        fake_B = fake_B[0,0,:,:]
        print(f'target images  shape 2 {real_A.shape}')
        '''print(f'target images  shape 2 {real_A.shape}')
        
        real_A = torch.reshape(real_A, (real_A.shape[0],real_A.shape[2],real_A.shape[2])).detach().cpu()
        #print(f'input_img images  shape {real_A.shape}')
        fake_B = torch.reshape(fake_B, (fake_B.shape[0],fake_B.shape[2],fake_B.shape[2])).detach().cpu()
        real_B = torch.reshape(real_B, (real_B.shape[0],real_B.shape[2],real_B.shape[2])).detach().cpu()

        #print(f'target images  shape 2 {real_B.shape}')
        show_img_sample3(real_A, fake_B, real_B)
        '''
        #print(f'target images  shape 2 {fake_B.shape[1:]}')
        #fake_B = torch.reshape(fake_B, (fake_B.shape[1],fake_B.shape[2],fake_B.shape[2])).detach().cpu()
        #real_B = torch.reshape(real_B, (real_B.shape[1],real_B.shape[2],real_B.shape[2])).detach().cpu()
        
        #show_img_sample_training(fake_B,real_B)
        #print(f'gradient penalty loss: {gradient_penalty}')
        #print(f'Discriminator loss: {loss_d}')
        
    scheduler_g.step()
    scheduler_d.step()
    lr_D = scheduler_d.get_last_lr()[0]
    lr_G = scheduler_g.get_last_lr()[0] 
     
    #lr_D = lr
    #lr_G = lr
    #return mean(total_loss_g), mean(total_loss_d), input_img.detach().cpu(), fake_img.detach().cpu(), real_img.detach().cpu(), lr_D, lr_G #.to(device) #.cpu() 
    return mean(total_loss_g), mean(total_loss_d), real_A.detach().cpu(), fake_B.detach().cpu(), real_B.detach().cpu(), lr_D, lr_G #.to(device) #.cpu() 
    #return mean(total_loss_g), mean(total_loss_d), real_A.detach().cpu(), fake_B.detach().cpu(), real_B.detach().cpu() #.to(device) #.cpu() 


def saving_img(input_img, fake_img, real_img, e):
    os.makedirs("generated", exist_ok=True)
    save_image(input_img, f"generated/input{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True) 
    save_image(fake_img, f"generated/fake{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True) 
    save_image(real_img, f"generated/real{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True) 
    # Changed 'range' to 'value_range' ==> 'range' will soon be deprecated.
    
def saving_logs(result):
    with open("train.pkl", "wb") as f:
        pickle.dump([result], f)
        
def saving_model(D, G, e):
    os.makedirs("weight", exist_ok=True)
    torch.save(G.state_dict(), f"weight/G{str(e+1)}.pth")
    torch.save(D.state_dict(), f"weight/D{str(e+1)}.pth")
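For completeness, weights saved by saving_model can be restored with load_state_dict; a sketch in which the file name is a placeholder for whichever epoch was saved:

import torch

G = Generator(in_channels=4, out_channels=1)
G.load_state_dict(torch.load("weight/G10.pth", map_location="cpu"))  # placeholder file
G.eval()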
        
def show_losses(g, d, rmse_T):
    fig, axes = plt.subplots(1, 3, figsize=(21,6)) # (14, 6)
    ax = axes.ravel()
    ax[0].plot(np.arange(len(d)).tolist(), d)
    ax[0].set_title("Discriminator Loss")
    ax[0].set_xlabel('Epochs')
    ax[0].set_ylabel('Loss')
    
    ax[1].plot(np.arange(len(g)).tolist(), g)
    ax[1].set_title("Generator Loss")
    ax[1].set_xlabel('Epochs')
    ax[1].set_ylabel('Loss')
    
    ax[2].plot(np.arange(len(rmse_T)).tolist(), rmse_T)
    ax[2].set_title("Root Mean Square Error")
    ax[2].set_xlabel('Epochs')
    ax[2].set_ylabel('RMSE')
    plt.show()

Both the test() and test_D() functions work fine with random inputs and return:

test()
Size of skip_connection4: torch.Size([10, 512, 8, 8])
Size of d6 before concat: torch.Size([10, 512, 8, 8])
Size of d6 after concat: torch.Size([10, 1024, 8, 8])
Size of skip_connection3: torch.Size([10, 256, 16, 16])
Size of d7 before concat: torch.Size([10, 256, 16, 16])
Size of d7 after concat: torch.Size([10, 512, 16, 16])
Size of skip_connection2: torch.Size([10, 128, 32, 32])
Size of d8 before concat: torch.Size([10, 128, 32, 32])
Size of d8 after concat: torch.Size([10, 256, 32, 32])
Size of skip_connection1: torch.Size([10, 64, 64, 64])
Size of d9 before concat: torch.Size([10, 64, 64, 64])
Size of d9 after concat: torch.Size([10, 128, 64, 64])
torch.Size([10, 1, 64, 64])
torch.Size([10, 4, 64, 64])
test_D()
pred tensor dimension: torch.Size([1, 1, 30, 30])
 image dimension: torch.Size([1, 4, 30, 30]) 

I’ve formatted your code, but note that you can post code snippets by wrapping them in three backticks ```.

I have 4 channels in the input image, and the label image is made up of 1 channel. They look like the following:

[example input and label images omitted]

This implies that real_B and fake_B are made up of 1 channel. Please point out any mistakes that are causing the error. Here is my training loop:
def train_loop(train_dl, G, D, num_epoch, lr=0.0002, lr_D=0.0004, betas=(0.5, 0.999), lr_gamma=.99999):

    G.to(device)
    D.to(device)
    optimizer_g = torch.optim.Adam(G.parameters(), lr=lr, betas=betas, weight_decay=3.0e-5)
    #optimizer_d = torch.optim.Adam(D.parameters(), lr=lr, betas=betas)
    optimizer_d = torch.optim.Adam(D.parameters(), lr=lr_D, betas=betas)
    scheduler_g = StepLR(optimizer_g, step_size=1, gamma=lr_gamma)
    scheduler_d = StepLR(optimizer_d, step_size=1, gamma=lr_gamma)  # the D scheduler steps its own optimizer

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

    #L1_loss = nn.L1Loss()
    # GAN_loss = nn.BCEWithLogitsLoss()
    # ROkwen
    #criterion_smoothL1 = torch.nn.SmoothL1Loss()
    #criterion_smoothL1 = torch.nn.MSELoss()

    # End
    total_loss_d, total_loss_g = [], []
    rmse_error = []
    image_rmseT = []
    ssim_T = []
    result = {}

    saved_samples = []
    start_time = time.time()
    batches_output = 1

    groups = {'Loss': ['Discriminator loss', 'Generator loss'], 'Learning rate': ['Discriminator', 'Generator'], 'Metrics': ['Root Mean Square Error']}
    #groups = {'Loss': ['Discriminator loss', 'Generator loss', 'Root Mean Square Error'], 'Learning rate': ['Discriminator', 'Generator']} # ,'SSIM'

    liveloss = PlotLosses(groups=groups)
    history = {}

    for e in range(num_epoch):
        # argument order matches train_fn(train_dl, G, D, GAN_loss, L1_loss, ...)
        loss_g, loss_d, real_A, fake_B, real_B, lr_D, lr_G = train_fn(train_dl, G, D, GAN_loss, L1_loss, optimizer_g, optimizer_d)
        total_loss_d.append(loss_d)
        total_loss_g.append(loss_g)

        image_rmse = np.sqrt(mean_squared_error(real_B.view(1, -1), fake_B.view(1, -1)))
        image_rmseT.append(image_rmse)

        ##
        real_img = real_B.view(1, -1).numpy() #save_image(real_B, f"generated/input{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True)
        fake_img = fake_B.view(1, -1).numpy() #save_image(fake_B, f"generated/input{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True)

        print(real_img.shape)
        print(fake_img.shape)
        ##
        #ssim_img = ssim(real_img, fake_img, data_range=fake_img.max() - fake_img.min(), channel_axis = 0) #multichannel=True)
        #saving_img(input_img, fake_img, real_img, e+1)

        # ROkwen
        # --------------
        #  Log Progress
        # --------------

        history['Discriminator loss'] = loss_d
        history['Generator loss'] = loss_g
        history['Discriminator'] = lr_D
        history['Generator'] = lr_G
        history['Root Mean Square Error'] = image_rmse
        #history['SSIM'] = ssim_img

        # Determine approximate time left
        batches_done = e * len(train_dl) # + i
        batches_left = num_epoch * len(train_dl) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - start_time) / (batches_done + 1))

        # Print log
        #"\r[Epoch %d/%d] [Batch %d/%d] [D adv: %f, aux: %f] [G loss: %f, adv: %f, aux: %f, cycle: %f] [Learning Rate: D: %f, G: %f] ETA: %s"
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D adv: %f] [G loss: %f] [Learning Rate: D: %.5f, G: %.5f] ETA: %s [RMSE: %f]"
            % (
                e,
                num_epoch,
                e,
                len(train_dl),
                loss_d, #.item(),
                #loss_d_cls.item(),
                loss_g, #.item(),
                #loss_g_adv.item(),
                #loss_g_cls.item(),
                #loss_g_rec.item(),
                lr_D,
                lr_G,
                time_left,
                image_rmse
            ))

        wandb.log({"num_epoch": num_epoch,
                   "Disc_Loss": loss_d,
                   "Gen_Loss": loss_g,
                   "RMSE": image_rmse,
                   #"SSIM": ssim_img
                  })

        sample_interval = 1
        if batches_done % sample_interval == 0: #i in range(3):
            liveloss.update(history)
            liveloss.send()
            print(f"Epoch: {e}")

            #clear_output()

            '''visualize(
            real_image = real_img, #np.squeeze(), #colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
            fake_image = fake_img, #np.squeeze() #reverse_one_hot(mask)
            input_image = input_img,
            )'''
            #print(f'output images shape {fake_img.shape[1]}')

            '''input_image = torch.reshape(input_img, (input_img.shape[1], input_img.shape[2], input_img.shape[2]))
            #print(f'input_img images shape {input_image.shape}')
            fake_img = torch.reshape(fake_img, (fake_img.shape[0], fake_img.shape[2], fake_img.shape[2]))
            real_img = torch.reshape(real_img, (real_img.shape[0], real_img.shape[2], real_img.shape[2]))'''
            #print(f'input_img images shape {real_A.shape}')
            real_A = torch.reshape(real_A, (real_A.shape[0], real_A.shape[2], real_A.shape[2])).squeeze()
            #print(f'input_img images shape {real_A.shape}')
            fake_B = torch.reshape(fake_B, (fake_B.shape[0], fake_B.shape[2], fake_B.shape[2]))
            real_B = torch.reshape(real_B, (real_B.shape[0], real_B.shape[2], real_B.shape[2]))

            #print(f'target images shape 2 {real_B.shape}')
            show_img_sample3(real_A, fake_B, real_B)

            #print(f'target images shape 2 {real_B}')

            #if batches_done % sample_interval == 0:

            #visualise_output(fake_img.data[:25], 10, 10)
            #visualize(
            #input_image = np.squeeze(real_A, axis=0), #real_A,
            #real_image = np.squeeze(real_B, axis=0), #np.squeeze(), #colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
            #fake_image = np.squeeze(fake_B, axis=0) #np.squeeze() #reverse_one_hot(mask)
            #)

            '''root = './generated/'
            for file in glob(root + "fake*.png" and root + "real*.png"):
                B = (re.findall('[0-9]', file)[0])
                filename = os.path.join(root, f"fake{str(B)}.png")
                #filename1 = os.path.join(root, f"real{str(B)}.png")
                visualise_output(filename, 16, 16)
                #visualise_output(filename1, 16, 16)'''

            #./generated
            #f"fake{str(B)}.png

        '''
        # If at sample interval, sample and save image
        if batches_done / sample_interval > batches_output:
            batches_output = batches_output + 1
            clear_output()
            visualise_output(sample_images(batches_done), 30, 30)'''

        ####
        # Save the model when the GAN loss is in (0.07, 0.7] and the discriminator loss is at least 0.5
        if loss_g <= 0.7 and loss_g > 0.07 and loss_d >= 0.5:
            show_losses(total_loss_g, total_loss_d, image_rmseT)
            print("GAN loss at epoch ", e, 'is ', loss_g, ' RMSE: ', image_rmse)
            #print("Disc loss at epoch ", e, 'is ', loss_d, ' SSIM: ', ssim_img)
            #if e%1 == 0:
            saving_model(D, G, e)
            saving_img(real_A, fake_B, real_B, e+1)
            print('Great! \n Model and images saved !')
            try:
                result["G"] = total_loss_g
                result["D"] = total_loss_d
                result["RMSE"] = image_rmseT
                saving_logs(result)
                show_losses(total_loss_g, total_loss_d, image_rmseT)
                saving_model(D, G, e)
                print("successfully saved model at epoch", e)
            finally:
                return G, D
    #wandb.finish()

I don’t know what’s causing the issue and your previously posted code runs fine for me. Are you able to reproduce the error using your posted code by calling test() and test_D()? If not, could you post a new minimal and executable code snippet reproducing the error?

I can’t reproduce the error with the test functions. Here is the error:
Size of skip_connection4: torch.Size([16, 512, 10, 10])
Size of d6 before concat: torch.Size([16, 512, 10, 10])
Size of d6 after concat: torch.Size([16, 1024, 10, 10])
Size of skip_connection3: torch.Size([16, 256, 20, 20])
Size of d7 before concat: torch.Size([16, 256, 20, 20])
Size of d7 after concat: torch.Size([16, 512, 20, 20])
Size of skip_connection2: torch.Size([16, 128, 40, 40])
Size of d8 before concat: torch.Size([16, 128, 40, 40])
Size of d8 after concat: torch.Size([16, 256, 40, 40])
Size of skip_connection1: torch.Size([16, 64, 80, 80])
Size of d9 before concat: torch.Size([16, 64, 80, 80])
Size of d9 after concat: torch.Size([16, 128, 80, 80])


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_157888/3384618826.py in <module>
      3 EPOCH = 101
      4 #BATCH_SIZE = 8 #3
----> 5 trained_G, trained_D = train_loop(train_dl, G, D, EPOCH)
      6 ###

/tmp/ipykernel_157888/97467904.py in train_loop(train_dl, G, D, num_epoch, lr, lr_D, betas, lr_gamma)
     37 
     38     for e in range(num_epoch):
---> 39         loss_g, loss_d, real_A, fake_B, real_B, lr_D, lr_G = train_fn(train_dl,G,D,L1_loss,GAN_loss, optimizer_g, optimizer_d)
     40         total_loss_d.append(loss_d)
     41         total_loss_g.append(loss_g)

/tmp/ipykernel_157888/3538503000.py in train_fn(train_dl, G, D, GAN_loss, L1_loss, optimizer_g, optimizer_d)
     28         # GAN loss
     29         fake_B = G(real_A)
---> 30         pred_fake_g = D(fake_B, real_A).squeeze() # ROkwen added .squeeze()
     31 
     32         pred_real_label = D(real_B, real_A).squeeze()

/opt/miniconda3/envs/opence-v1.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/tmp/ipykernel_157888/2091396250.py in forward(self, x, x1)
     36         x = torch.cat((x, x1), dim=1)
     37 
---> 38         out1 = self.layer1(x)
     39 
     40 

/opt/miniconda3/envs/opence-v1.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/miniconda3/envs/opence-v1.0.0/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

/opt/miniconda3/envs/opence-v1.0.0/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/miniconda3/envs/opence-v1.0.0/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    417 
    418     def forward(self, input: Tensor) -> Tensor:
--> 419         return self._conv_forward(input, self.weight)
    420 
    421 class Conv3d(_ConvNd):

/opt/miniconda3/envs/opence-v1.0.0/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    413                         weight, self.bias, self.stride,
    414                         _pair(0), self.dilation, self.groups)
--> 415         return F.conv2d(input, weight, self.bias, self.stride,
    416                         self.padding, self.dilation, self.groups)
    417 

RuntimeError: Given groups=1, weight of size [16, 2, 5, 5], expected input[16, 5, 80, 80] to have 2 channels, but got 5 channels instead

The issue is raised in Discriminator.layer1: it expects an input with 2 channels, while the activation passed to it, created via x = torch.cat((x, x1), dim=1), has 5 channels. The weight of size [16, 2, 5, 5] in the error message is layer1's first convolution, which was built as conv2relu(in_channels*2, hidden_channels//2, 5, cnt=1) with the default in_channels=1, so it accepts 1*2 = 2 input channels; during training, however, it receives the 1-channel fake_B/real_B concatenated with the 4-channel real_A, i.e. 5 channels. (test_D only passes because it constructs the model with in_channels=4, giving 4*2 = 8 = 4 + 4.)
Check the shapes of x and x1 and make sure the number of channels of the concatenated tensor corresponds to what layer1 expects. If you get stuck, update the executable code and make sure it raises the error.
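As a minimal sketch of the mismatch and of one possible fix (assuming, as in the question, a 1-channel generator output conditioned on a 4-channel input), the first block needs to be sized for the sum of the concatenated channel counts rather than for in_channels*2 with the default in_channels=1:

import torch
import torch.nn as nn

# Stand-in for the first discriminator block as posted: Conv(in_c -> out_c, k=5) + BN + LeakyReLU.
def first_block(in_c, out_c, kernel_size=5):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size, padding=(kernel_size - 1) // 2),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.2, inplace=True),
    )

fake_B = torch.randn(16, 1, 80, 80)     # generator output: 1 channel
real_A = torch.randn(16, 4, 80, 80)     # condition image: 4 channels
x = torch.cat((fake_B, real_A), dim=1)  # 1 + 4 = 5 channels

broken = first_block(1 * 2, 16)         # in_channels*2 with in_channels=1 -> weight [16, 2, 5, 5]
try:
    broken(x)
except RuntimeError as e:
    print(e)                            # ... expected input[16, 5, 80, 80] to have 2 channels ...

fixed = first_block(1 + 4, 16)          # G's out_channels + condition channels
print(fixed(x).shape)                   # torch.Size([16, 16, 80, 80])

In the posted class that could mean building layer1 with something like in_channels + cond_channels (1 + 4 = 5 here, with cond_channels a new constructor argument) instead of in_channels*2, or instantiating Discriminator so that in_channels*2 matches the channel count of the concatenated input.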