Target size (torch.Size([3, 3, 256, 256])) must be the same as input size (torch.Size([3, 65536]))

Hello!

I’m new to machine learning and PyTorch, and I’m stuck on an error that seems really simple, but I can’t figure out where to fix it:

ValueError                                Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_52832/2153252887.py in <module>
      5 n_epochs  = 10
      6 
----> 7 train_unet(unet, trainloader, valloader, optimizer, loss, n_epochs=n_epochs)
      8 # test_unet(mlp_net, testloader, loss)

~\AppData\Local\Temp/ipykernel_52832/3077527990.py in train_unet(net, trainloader, valloader, optimizer, loss_function, n_epochs)
     24 
     25             y_pred = net(X.float())
---> 26             y_loss = loss_function(y_pred, y)
     27             y_dice = dice(y, y_pred)
     28 

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
    711 
    712     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 713         return F.binary_cross_entropy_with_logits(input, target,
    714                                                   self.weight,
    715                                                   pos_weight=self.pos_weight,

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   3128 
   3129     if not (target.size() == input.size()):
-> 3130         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   3131 
   3132     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([3, 3, 256, 256])) must be the same as input size (torch.Size([3, 65536]))

My model returns an output of size torch.Size([3, 65536]), while the labels are images of torch.Size([3, 3, 256, 256]).

The images are of size torch.Size([3, 3, 256, 256]) and not torch.Size([3, 256, 256]). The first 3 is the batch size, the second one is for the 3 RGB channels, and the 256s are the image dimensions.

Clearly, the model returns a size of 65536, which is 256*256, so it flattens each image into a line. But how do I convert it back into a square? And what do I do about the extra 3 (the channel dimension) in the labels?

Here’s a reproducible example:

import torch
import torch.nn as nn

from torch.nn                    import Linear, Module, ModuleList, ReLU, Sequential, BCEWithLogitsLoss, Conv2d, ConvTranspose2d, MaxPool2d
from torch.optim                 import SGD, Adam
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset    import Dataset

from skimage.io        import imread
from skimage.transform import resize

from glob import glob
import os

import matplotlib.pyplot as plt
import numpy as np

# train_unet below moves batches to `device`, so define it here
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



class TissueDataset(Dataset):
    def __init__(self, img_path, target_path):
        super().__init__()
        self.imgs    = glob(os.path.join(img_path, "*.jpg"))
        self.targets = glob(os.path.join(target_path, "*.jpg"))
    
    def __getitem__(self, idx):
        size = (3, 256, 256)

        image = imread(self.imgs[idx])
        label = imread(self.targets[idx])

        image = resize(image, size, order = 1, preserve_range = True)
        label = resize(label, size, order = 0, preserve_range = True).astype(int)

        return image, label
    
    def __len__(self):
        return len(self.imgs)



def makeLoader(split):
    dataloader = DataLoader(TissueDataset(
        img_path    = f'data/tissue/{split}/jpg',
        target_path = f'data/tissue/{split}/lbl'
    ), batch_size = 3, shuffle = True)

    return dataloader


trainloader = makeLoader('train')
valloader   = makeLoader('val')
testloader  = makeLoader('test')



class DoubleConv2d(Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv2d, self).__init__()

        self.stack = Sequential(
            Conv2d(in_channels, out_channels, 3, 1, 1),
            ReLU(),
            Conv2d(out_channels, out_channels, 3, 1, 1),
            ReLU()
        )


    def forward(self, x):
        return self.stack(x)



class UNet(Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        inn = 3
        out = 1
        mid = [112, 224, 448]

        self.encoder = ModuleList()
        self.bottom  = DoubleConv2d(mid[-1], 2*mid[-1]) # should both be mid[-1]?
        self.decoder = ModuleList()
        self.end     = Conv2d(mid[0], 1*out, 1)

        self.maxpool = MaxPool2d(2, 2)
        #self.linear  = Linear(65536, 10)

        for dim in mid:
            self.encoder.append(DoubleConv2d(inn, dim))
            inn = dim

        for dim in mid[::-1]:
            self.decoder.append(ConvTranspose2d(2*dim, dim, 2, 2))
            self.decoder.append(DoubleConv2d(2*dim, dim))


    def forward(self, x):
        connections = []

        for i in range(len(self.encoder)):
            module = self.encoder[i]

            x = module(x)
            connections.append(x)
            x = self.maxpool(x)

        x = self.bottom(x)

        for i in range(len(self.decoder)):
            module = self.decoder[i]
            x = module(x)

            if i % 2 == 0: # ConvTranspose2d
                connection = connections.pop()
                x = torch.cat((connection, x), dim=1)
        
        x = self.end(x)

        x = x.view(x.size(0), -1)
        #x = self.linear(x)

        return x



# print(UNet())
# unet = UNet()
# x = torch.randn(3, 3, 256, 256)
# out = unet(x)



def dice(ytrue, ypred):
    if ytrue.dim() == 2:
        # Dice coefficient for a single [H, W] mask
        inter = torch.dot(ytrue.reshape(-1), ypred.reshape(-1))

        cardA = torch.sum(ytrue)
        cardB = torch.sum(ypred)

        num = 2 * inter     + 1e-6
        den = cardA + cardB + 1e-6

        return num / den

    elif ytrue.dim() > 2:  # recurse over the leading (batch or channel) dimension
        batch_size = ytrue.shape[0]
        total = 0

        for i in range(batch_size):
            total += dice(ytrue[i, ...], ypred[i, ...])

        return total / batch_size
    
    return None



def train_unet(net, trainloader, valloader, optimizer, loss_function, n_epochs):
    net = net.float()

    t_size = len(trainloader.dataset)
    v_size = len(valloader.dataset)

    graphs = {
        'train_loss' : [],
        'val_loss'   : [],
        'train_dice' : [],
        'val_dice'   : []
    }

    for epoch in range(1, n_epochs+1):
        net.train()
        y_loss = 0
        y_dice = 0
        w_loss = 0
        w_dice = 0

        for X_batch, (X, y) in enumerate(trainloader):
            X, y = X.to(device), y.to(device)
            
            y_pred = net(X.float())
            y_loss = loss_function(y_pred, y)
            y_dice = dice(y, y_pred)

            optimizer.zero_grad()
            y_loss.backward()
            optimizer.step()

            if X_batch % 100 == 0:
                net.eval()
                w_loss = 0
                w_dice = 0

                with torch.no_grad():  # no gradients needed for validation
                    for V_batch, (V, w) in enumerate(valloader):
                        V, w = V.to(device), w.to(device)

                        w_pred  = net(V.float())
                        w_loss  = loss_function(w_pred, w)
                        w_dice += dice(w, w_pred)

                w_loss = w_loss / v_size
                w_dice = w_dice / v_size
                net.train()  # back to training mode

            graphs['train_loss'].append(y_loss.item())
            graphs['train_dice'].append(y_dice)
            graphs['val_loss'].append(w_loss.item())
            graphs['val_dice'].append(w_dice)

    return graphs



unet      = UNet()

optimizer = Adam(unet.parameters(), lr=0.001)
loss      = BCEWithLogitsLoss()
n_epochs  = 10

train_unet(unet, trainloader, valloader, optimizer, loss, n_epochs=n_epochs)

Or if you prefer it as a file: https://we.tl/t-0oAjeZVZQ9
And here is the data: https://we.tl/t-9g3tUYCJ3d


Edit 1: After getting rid of the line x = x.view(x.size(0), -1) and using a batch size of 4 instead of 3 (for clarity), the error becomes:

Target size (torch.Size([4, 3, 256, 256])) must be the same as input size (torch.Size([4, 1, 256, 256]))

Edit 2: After using sizes (3, 256, 256) for the images and (1, 256, 256) for the labels, and removing .astype(int) from the __getitem__ method, I get this error:

TypeError: conv2d() received an invalid combination of arguments - got (builtin_function_or_method, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!builtin_function_or_method!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!builtin_function_or_method!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)

Hi Chris!

Based on your edits, it looks like you are on the right track.

Let me make some comments based on my speculation about your use
case and what you are trying to do.

Based on the name TissueDataset, the name UNet, and the use of
BCEWithLogitsLoss as your loss criterion, I assume that you are
performing binary semantic segmentation. That is, you wish to classify
each pixel in your input as being in either “class 0” (background or healthy
tissue or whatever) or “class 1” (foreground or diseased tissue or whatever).
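
With that loss, every pixel gets its own binary-cross-entropy term. Roughly (a sketch with made-up tensors, just to show the expected shapes):

criterion = BCEWithLogitsLoss()
logits = torch.randn(4, 1, 256, 256)                     # raw, pre-sigmoid model output
target = torch.randint(0, 2, (4, 1, 256, 256)).float()   # per-pixel 0/1 labels, float
loss = criterion(logits, target)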

As a general rule (and as a requirement for convolutional networks),
pytorch networks work with batches of inputs (and outputs and labels).
If you want to process a single image, you still need to package it as a
batch with a batch size of one.
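
For example (just a sketch), a single image becomes a batch of one with unsqueeze():

img = torch.randn(3, 256, 256)    # a single 3-channel image
batch = img.unsqueeze(0)          # add the batch dimension -> [1, 3, 256, 256]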

The input to your model should have shape [nBatch, nChannels, H, W],
where in your case nChannels = 3 (presumably the RGB channels of a
color image), and it should have type float (or double).

Your output and labels should both have shape [nBatch, 1, H, W] (or just
[nBatch, H, W]), and also be of type float.
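
A quick sanity check along these lines can save a lot of guessing (a sketch; the exact dtypes depend on what your Dataset returns):

X, y = next(iter(trainloader))
print(X.shape, X.dtype)    # should be [nBatch, 3, 256, 256], a float type
print(y.shape, y.dtype)    # should be [nBatch, 1, 256, 256], a float type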

Note, UNet does not (typically) have H and W wired into it – the same
UNet can be trained on, and perform inference on, images of differing
shapes – but any given batch has to consist of images of the same
shape.
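
For example (a sketch, assuming the final flattening discussed below has been removed):

unet(torch.randn(2, 3, 256, 256))    # output shape [2, 1, 256, 256]
unet(torch.randn(2, 3, 128, 128))    # output shape [2, 1, 128, 128], same weights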

So, yes, get rid of the final x = x.view(x.size(0), -1) (and the
commented-out x = self.linear(x)), and have your model output
the result of the final Conv2d (self.end) layer. And, yes, out = 1 is
correct.
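
That is, the tail of forward() would end up looking something like:

x = self.end(x)    # shape [nBatch, 1, H, W]
return x           # no flattening, no linear layer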

If your labels are indeed (batches of) images of shape [3, 3, 256, 256],
then you have to figure out how they are “encoded” to give binary class
labels. Could they be pure black-and-white images that happen to be
encoded as three-channel RGB images?

In any event, you have to process your labels “images” to be single-channel
binary labels (of type float). (Your labels don’t actually have to be pure
binary, that is, exactly zero or one – they can be probabilistic labels that
run from zero to one.)
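
If, for example, the masks are 8-bit images, something along these lines could work (a sketch – the channel selection and the 127 threshold are assumptions you should check against your actual data):

# label: numpy array from imread, either [H, W] or [H, W, 3]
if label.ndim == 3:
    label = label[..., 0]                  # channels presumably identical for a B/W mask
mask = (label > 127).astype(np.float32)    # binarize to {0.0, 1.0}
mask = mask[None, ...]                     # add the channel dim -> [1, H, W]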

Best.

K. Frank


Thank you so much for your detailed answer! I am indeed trying to solve a segmentation problem. Here is an example of an image and its label:

[example image and its label]

If I open a label image in Photoshop, I can see the image mode is set to Grayscale and not RGB, so the shape should be [3, 1, 256, 256], right?

And this is what I see when I print(y_pred.shape, y.shape), so it’s probably correct.

torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])

All looks fine now, but I still have this error I can’t make sense of:

TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_26916/903322469.py in <module>
      5 n_epochs  = 10
      6 
----> 7 train_unet(unet, trainloader, valloader, optimizer, loss, n_epochs=n_epochs)
      8 # test_unet(unet, testloader, loss)

~\AppData\Local\Temp/ipykernel_26916/239980266.py in train_unet(net, trainloader, valloader, optimizer, loss_function, n_epochs)
     40                     V, w = V.to(device), w.to(device)
     41 
---> 42                     w_pred  = net(V.float)
     43                     w_loss  = loss_function(w_pred, w)
     44                     w_dice += dice(w, w_pred)

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Local\Temp/ipykernel_26916/3729299638.py in forward(self, x)
     30             module = self.encoder[i]
     31 
---> 32             x = module(x)
     33             connections.append(x)
     34             x = self.maxpool(x)

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Local\Temp/ipykernel_26916/3683377085.py in forward(self, x)
     12 
     13     def forward(self, x):
---> 14         return self.stack(x)

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\conv.py in forward(self, input)
    445 
    446     def forward(self, input: Tensor) -> Tensor:
--> 447         return self._conv_forward(input, self.weight, self.bias)
    448 
    449 class Conv3d(_ConvNd):

~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\conv.py in _conv_forward(self, input, weight, bias)
    441                             weight, bias, self.stride,
    442                             _pair(0), self.dilation, self.groups)
--> 443         return F.conv2d(input, weight, bias, self.stride,
    444                         self.padding, self.dilation, self.groups)
    445 

TypeError: conv2d() received an invalid combination of arguments - got (builtin_function_or_method, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!builtin_function_or_method!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!builtin_function_or_method!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)

For reference, this is the UNet I’m trying to build (ignoring the dropouts).

Never mind, after hours of work I found the error. A silly mistake, as always :man_facepalming:

I was calling w_pred = net(V.float) instead of w_pred = net(V.float()).
It’s written right there in the error message I kept staring at, but I kept missing it until I posted it here :smiling_face_with_tear:
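
For anyone who runs into the same conv2d() TypeError, the difference is literally just the parentheses:

w_pred = net(V.float)      # passes the bound method itself -> conv2d() TypeError
w_pred = net(V.float())    # passes the converted tensor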

Thanks a lot!