RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (4 x 4). Kernel size can't be greater than actual input size

I am trying to train a GAN to transfer style. I am getting an error when passing images through the discriminator:

for epoch in range(epochs):
    # code for stats
    for real_images in tqdm(t_dl):

        optimizer["discriminator"].zero_grad()

        real_preds = model["discriminator"](real_images)  # <----------------------- error here
        # code

And here is the model:

model = {
    "discriminator": discriminator.to(device),
    "generator": generator.to(device)
}

And the code for the discriminator:

discriminator = nn.Sequential(
    # in: 3 x 256 x 256
    PrintLayer(),
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 64 x 128 x 128
    PrintLayer(),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 128 x 64 x 64
    PrintLayer(),
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 256 x 32 x 32
    PrintLayer(),
    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 512 x 16 x 16
    PrintLayer(),
    nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(1024),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 1024 x 8 x 8
    PrintLayer(),
    nn.Conv2d(1024, 1024, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(1024),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 1024 x 4 x 4
    PrintLayer(),
    nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1
    PrintLayer(),
    nn.Flatten(),
    nn.Sigmoid())

So I added a PrintLayer() to check the dimensions after each convolution:

class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        print(x.shape)
        return x
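
With the PrintLayer in place, the discriminator can also be sanity-checked in isolation with a random batch; a minimal sketch, assuming the discriminator and device defined above:

import torch

# a random 256x256 batch should reproduce the commented shape progression
# and end up as a [2, 1] prediction tensor
x = torch.randn(2, 3, 256, 256, device=device)
out = discriminator(x)
print(out.shape)  # expected: torch.Size([2, 1])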

All images in the batch are 256*256; I printed the image sizes right before passing them to the discriminator:

0 torch.Size([3, 256, 256])
1 torch.Size([3, 256, 256])
2 torch.Size([3, 256, 256])
3 torch.Size([3, 256, 256])
4 torch.Size([3, 256, 256])
5 torch.Size([3, 256, 256])
6 torch.Size([3, 256, 256])
7 torch.Size([3, 256, 256])
8 torch.Size([3, 256, 256])
9 torch.Size([3, 256, 256])

It works with the first image, but somehow the second image is 112*112:

torch.Size([10, 3, 256, 256])
torch.Size([10, 64, 128, 128])
torch.Size([10, 128, 64, 64])
torch.Size([10, 256, 32, 32])
torch.Size([10, 512, 16, 16])
torch.Size([10, 1024, 8, 8])
torch.Size([10, 1024, 4, 4])
torch.Size([10, 1, 1, 1])
torch.Size([10, 3, 112, 112])
torch.Size([10, 64, 56, 56])
torch.Size([10, 128, 28, 28])
torch.Size([10, 256, 14, 14])
torch.Size([10, 512, 7, 7])
torch.Size([10, 1024, 3, 3])
torch.Size([10, 1024, 1, 1])

The error is most likely raised in a conv layer when the input activation would create an empty output.
Your debugging steps look good and it seems to be the last conv layer.

This would explain the issue, as 112x112 is too small for this architecture, and you should check why the “second image” has this smaller spatial size given you’ve apparently verified that all images are 256x256.
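
The spatial size after each of these stride-2 convs follows the usual formula out = floor((in + 2*padding - kernel) / stride) + 1, so a quick calculation (a minimal sketch, independent of your model code) shows why a 256x256 input survives all six downsampling layers while a 112x112 input does not:

def conv_out(size, k=4, s=2, p=1):
    # standard conv output size: floor((size + 2*p - k) / s) + 1
    return (size + 2 * p - k) // s + 1

for start in (256, 112):
    size = start
    for _ in range(6):          # the six stride-2 convs of the discriminator
        size = conv_out(size)
    print(start, '->', size)    # 256 -> 4, 112 -> 1

# the final conv (kernel_size=4, stride=1, padding=0) needs at least a 4x4
# input, so it works for the 4x4 case but fails for the 1x1 case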

@ptrblck thanks for the quick reply.

I am a newbie in ML, can you please explain what you meant by

The error is most likely raised in a conv layer when the input activation would create an empty output.

How can it create an empty output if I am passing an image?

and you should check why the “second image” has this smaller spatial size given you’ve apparently verified that all images are 256x256.

But that was the point of my question. I do not understand why the given image is 112*112 if right before the discriminator I did

for i in range(len(real_images)):
    print(f'{i} {real_images[i].shape}')

and it printed 1 torch.Size([3, 256, 256]).

You did a fantastic job in trying to debug the error with the custom PrintLayer, so don’t worry about being new in ML.

The last conv layer would get an input activation of [batch_size, 1024, 1, 1] and tries to apply a conv kernel with a spatial size of 4x4 which will not work as the output spatial size would be smaller than 1 and thus empty:

conv = nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False)
x = torch.randn(2, 1024, 1, 1)
out = conv(x)
# RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (4 x 4). Kernel size can't be greater than actual input size

That’s also a good check, but it does not explain where the smaller images are coming from.
Could you check if you are using another DataLoader loop (e.g. during validation or so) which could create smaller images?
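
A minimal way to check this outside of the training loop (assuming the t_dl from your code) would be:

# print the shape of every batch the DataLoader yields;
# each one should be [batch_size, 3, 256, 256]
for i, batch in enumerate(t_dl):
    print(i, batch.shape)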

This is how I prepared the images:

import os

import torch
import torchvision.transforms as tt
from skimage.io import imread
from torch.utils.data import DataLoader

root = '/tmp/pix2pix'

images = []
for file in os.listdir(root):
    images.append(imread(os.path.join(root, file)))

for i in range(len(images)):
    images[i] = torch.Tensor(images[i])      # float tensor, values still in [0, 255]
    images[i] = images[i].permute(2, 0, 1)   # from HWC to CHW

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
SIZE = 256

trs = tt.Compose([
    tt.ToPILImage(),
    tt.ToTensor(),
    tt.Resize((SIZE, SIZE)),
    tt.Normalize(mean, std)]
)
images = [trs(images[i]) for i in range(len(images))]

batch_size = 10
t_dl = DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

I have only one DataLoader because it is a GAN, so I will just look at the loss graphs and use my own images to check how the GAN transfers style. And the error occurred in the first epoch.
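
To rule out the preprocessing itself, the shapes in images could also be double-checked before building the DataLoader; a minimal check:

# list every preprocessed tensor that is not 3 x SIZE x SIZE
bad = [(i, img.shape) for i, img in enumerate(images) if img.shape != (3, SIZE, SIZE)]
print(bad)  # an empty list means all tensors are 3 x 256 x 256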

This code snippet also looks good.
Could you add a print statement to the actual training loop and just check the shape of each input tensor to the model?
Are you overriding the input variables (e.g. images) later in your code somehow?
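
A quick way to make such a mismatch fail loudly (a sketch against the loop you posted) is an assert at the top of the batch loop:

for real_images in tqdm(t_dl):
    # fail fast if a batch does not have the expected spatial size
    assert real_images.shape[1:] == (3, 256, 256), real_images.shape
    # ... rest of the training step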

I rewrote the code in OOP style to debug better; in the Discriminator class I added a print statement to forward. Here is everything that my code printed:

0 torch.Size([3, 256, 256])
1 torch.Size([3, 256, 256])
2 torch.Size([3, 256, 256])
3 torch.Size([3, 256, 256])
4 torch.Size([3, 256, 256])
5 torch.Size([3, 256, 256])
6 torch.Size([3, 256, 256])
7 torch.Size([3, 256, 256])
8 torch.Size([3, 256, 256])
9 torch.Size([3, 256, 256])
image shape torch.Size([10, 3, 256, 256])
torch.Size([10, 3, 256, 256])
torch.Size([10, 64, 128, 128])
torch.Size([10, 128, 64, 64])
torch.Size([10, 256, 32, 32])
torch.Size([10, 512, 16, 16])
torch.Size([10, 1024, 8, 8])
torch.Size([10, 1024, 4, 4])
torch.Size([10, 1, 1, 1])
torch.Size([10, 1])
torch.Size([10, 1])
image shape torch.Size([10, 3, 112, 112])
torch.Size([10, 3, 112, 112])
torch.Size([10, 64, 56, 56])
torch.Size([10, 128, 28, 28])
torch.Size([10, 256, 14, 14])
torch.Size([10, 512, 7, 7])
torch.Size([10, 1024, 3, 3])
torch.Size([10, 1024, 1, 1])

And here is the refactored code:

class BaseBlock(nn.Module):
    def __init__(self, in_ch, out_ch, k_size=4, stride=2, padding=1, bias=False):
        super().__init__()

        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding, bias=bias)
        self.b_norm = nn.BatchNorm2d(out_ch)
        self.actv = nn.LeakyReLU(inplace=True)
        self.print = PrintLayer()

    def forward(self, x):

        x = self.print(x)
        x = self.conv(x)
        x = self.b_norm(x)
        x = self.actv(x)

        return x

class Discriminator(nn.Module):
    def __init__(self, chs=(3, 64, 128, 256, 512, 1024, 1024)):
        super().__init__()

        # nn.ModuleList registers the blocks as submodules, so their parameters
        # are returned by .parameters() and move together with .to(device)
        self.convs = nn.ModuleList([BaseBlock(chs[i], chs[i+1]) for i in range(len(chs) - 1)])
        self.last_conv = nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False)

        self.sigm = nn.Sigmoid()
        self.fltn = nn.Flatten()
        self.print = PrintLayer()

    def forward(self, x):
        print(f'image shape {x.shape}')
        for conv in self.convs:
            x = conv(x)
        
        x = self.print(x)
        x = self.last_conv(x)
        x = self.print(x)
        x = self.fltn(x)
        x = self.print(x)
        x = self.sigm(x)
        x = self.print(x)

        return x

Are you overriding the input variables (e.g. images) later in your code somehow?

No, after preparing the images I am defining the model and the train loop:

def fit(model, criterion, epochs, lr, start_idx=1):
    model["discriminator"].train()
    model["generator"].train()
    torch.cuda.empty_cache()
    
    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []
    
    optimizer = {
        "discriminator": torch.optim.Adam(model["discriminator"].parameters(), 
                                          lr=lr, betas=(0.5, 0.999)),
        "generator": torch.optim.Adam(model["generator"].parameters(),
                                      lr=lr, betas=(0.5, 0.999))
    }
    
    for epoch in range(epochs):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []
        for real_images in tqdm(t_dl):
            # Train discriminator
            # Clear discriminator gradients
            optimizer["discriminator"].zero_grad()

            # Pass real images through discriminator
            for i in range(len(real_images)):
                print(f'{i} {real_images[i].shape}')
            real_images = real_images.to(device)
            real_preds = model["discriminator"](real_images)#----------error here
            real_targets = torch.ones(real_images.size(0), 1, device=device)
            real_loss = criterion["discriminator"](real_preds, real_targets)
            cur_real_score = torch.mean(real_preds).item()
            #code

OK, that’s interesting as it seems the second batch uses smaller images:

0 torch.Size([3, 256, 256])
1 torch.Size([3, 256, 256])
2 torch.Size([3, 256, 256])
3 torch.Size([3, 256, 256])
4 torch.Size([3, 256, 256])
5 torch.Size([3, 256, 256])
6 torch.Size([3, 256, 256])
7 torch.Size([3, 256, 256])
8 torch.Size([3, 256, 256])
9 torch.Size([3, 256, 256])
image shape torch.Size([10, 3, 256, 256])
torch.Size([10, 3, 256, 256])
torch.Size([10, 64, 128, 128])
torch.Size([10, 128, 64, 64])
torch.Size([10, 256, 32, 32])
torch.Size([10, 512, 16, 16])
torch.Size([10, 1024, 8, 8])
torch.Size([10, 1024, 4, 4])
torch.Size([10, 1, 1, 1])
torch.Size([10, 1])
torch.Size([10, 1])
image shape torch.Size([10, 3, 112, 112])  # !!!!! HERE !!!
torch.Size([10, 3, 112, 112])
...

so it seems the t_dl returns batches containing images in different shapes.

EDIT:
However, based on the print statements the smaller size is not coming from the calls here:

            # Pass real images through discriminator
            for i in range(len(real_images)):
                print(f'{i} {real_images[i].shape}')
            real_images = real_images.to(device)
            real_preds = model["discriminator"](real_images)

since the print statements for each sample are missing, so I guess you are calling model["discriminator"] later in the code which is not posted here.

Full code of the train loop:

def fit(model, criterion, epochs, lr, start_idx=1):
    model["discriminator"].train()
    model["generator"].train()
    torch.cuda.empty_cache()
    
    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []
    
    # Create optimizers
    optimizer = {
        "discriminator": torch.optim.Adam(model["discriminator"].parameters(), 
                                          lr=lr, betas=(0.5, 0.999)),
        "generator": torch.optim.Adam(model["generator"].parameters(),
                                      lr=lr, betas=(0.5, 0.999))
    }
    
    for epoch in range(epochs):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []
        for real_images in tqdm(t_dl):
            # Train discriminator
            # Clear discriminator gradients
            optimizer["discriminator"].zero_grad()

            # Pass real images through discriminator
            for i in range(len(real_images)):
                print(f'{i} {real_images[i].shape}')
            real_images = real_images.to(device)
            real_preds = model["discriminator"](real_images)
            real_targets = torch.ones(real_images.size(0), 1, device=device)
            real_loss = criterion["discriminator"](real_preds, real_targets)
            cur_real_score = torch.mean(real_preds).item()
            
            # Generate fake images
            latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
            fake_images = model["generator"](latent)

            # Pass fake images through discriminator
            fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
            fake_preds = model["discriminator"](fake_images)
            fake_loss = criterion["discriminator"](fake_preds, fake_targets)
            cur_fake_score = torch.mean(fake_preds).item()

            real_score_per_epoch.append(cur_real_score)
            fake_score_per_epoch.append(cur_fake_score)

            # Update discriminator weights
            loss_d = real_loss + fake_loss
            loss_d.backward()
            optimizer["discriminator"].step()
            loss_d_per_epoch.append(loss_d.item())


            # Train generator
            # Clear generator gradients
            optimizer["generator"].zero_grad()
            
            # Generate fake images
            latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
            fake_images = model["generator"](latent)
            
            # Try to fool the discriminator
            preds = model["discriminator"](fake_images)
            targets = torch.ones(batch_size, 1, device=device)
            loss_g = criterion["generator"](preds, targets)
            
            # Update generator weights
            loss_g.backward()
            optimizer["generator"].step()
            loss_g_per_epoch.append(loss_g.item())
            
        # Record losses & scores
        losses_g.append(np.mean(loss_g_per_epoch))
        losses_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_score_per_epoch))
        fake_scores.append(np.mean(fake_score_per_epoch))
        
        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, 
            losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1]))
    
        # Save generated images
        if epoch == epochs - 1:
          save_samples(epoch+start_idx, fixed_latent, show=False)
    
    return losses_g, losses_d, real_scores, fake_scores

so I guess you are calling model["discriminator"] later in the code

EDIT: No, as I said before, the error occurred in the first epoch, so the other part of the train loop is not executed.

But you are calling the discriminator again in the same loop as I’ve assumed:

real_preds = model["discriminator"](real_images)
...
fake_preds = model["discriminator"](fake_images)
...
preds = model["discriminator"](fake_images)

Again, check all calls to the discriminator and try to narrow down where the exact error is raised.
If in doubt, add print statements after each line of code and see where it crashes.
Based on the updated code I would start with checking fake_images.
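
For example (a sketch against your loop, not a confirmed fix), the generator output could be checked right before it is passed to the discriminator:

# inside the training loop, right after generating the fake batch
fake_images = model["generator"](latent)
print(fake_images.shape)  # expected: torch.Size([batch_size, 3, 256, 256])
assert fake_images.shape[-2:] == (256, 256), fake_images.shape
fake_preds = model["discriminator"](fake_images)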

Thanks for the help @ptrblck!

I checked it, and that was the problem. The generator generated images of the wrong shape, so the discriminator was getting 112*112 inputs.
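
For anyone hitting the same error: the generator has to contain enough stride-2 upsampling layers to reach the 256x256 input size the discriminator expects. A minimal DCGAN-style sketch that goes from a 1x1 latent all the way to 256x256 (latent_size and the channel counts here are assumptions, not my original generator):

import torch
import torch.nn as nn

latent_size = 128  # assumed; use whatever your training loop passes in

generator = nn.Sequential(
    # in: latent_size x 1 x 1
    nn.ConvTranspose2d(latent_size, 1024, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(1024),
    nn.ReLU(True),
    # 1024 x 4 x 4
    nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # 512 x 8 x 8
    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # 256 x 16 x 16
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # 128 x 32 x 32
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # 64 x 64 x 64
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(True),
    # 32 x 128 x 128
    nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh())
    # out: 3 x 256 x 256

# quick check that the output matches the discriminator's expected input
print(generator(torch.randn(2, latent_size, 1, 1)).shape)  # torch.Size([2, 3, 256, 256])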