Target Size different to Input Size when computing discriminator loss for DCGAN

I’m a high schooler who’s (very!) new to machine learning and PyTorch, so I’ve been trying to build a DCGAN following the PyTorch official tutorial to replicate samples of melanoma lesions from the SIIM-ISIC Melanoma Classification dataset. However, I keep getting this error when trying to compute my discriminator loss on real data:


Using a target size (torch.Size([64])) that is different to the input size (torch.Size([14400])) is deprecated. Please ensure they have the same size.

It turns out my ‘output’ shape is 14400, but I’m not sure how it got there, although I’d guess it’s an issue with the way I load data or with the layers in my discriminator. I’ve looked at some of the related posts on here, but I’m not certain they match my situation. I’m also inclined to believe the issue is with the way I’m loading data, since I followed the main generator/discriminator/train code from the aforementioned PyTorch DCGAN tutorial. Could I get a few pointers in the right direction? Anything would help! Also, if there’s anything I should amend or add to this post, apologies in advance, and please let me know.

I’m using a batch size of 64 and a BCELoss function. I’m not sure if this would help, but here is my train_transform:

train_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((299, 299)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

Here’s a row of the csv file I use (each ISIC_number corresponds to a jpg image with the same number), with the columns (image_name, patient_id, sex, age_approx, anatom_site_general_challenge, diagnosis, benign_malignant, target):

ISIC_0149568 IP_0962375 female 55.0 torso melanoma malignant 1

I’m also using this to create, initialize, and load my dataset:

import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader

class PreGANDataset(torch.utils.data.Dataset):
  def __init__(self, csv_file_path, root_path, transform=None):
    self.annotations = pd.read_csv(csv_file_path)
    self.root_path = root_path
    self.transform = transform
  def __len__(self):
    return len(self.annotations)
  def __getitem__(self, index):
    img_path = os.path.join(self.root_path, self.annotations.iloc[index, 0] + ".jpg")
    img = Image.open(img_path)
    y_label = torch.tensor(int(self.annotations.iloc[index, 7]))
    if self.transform:
      img = self.transform(img)
    return img, y_label

train_dataset = PreGANDataset(csv_file_path="/content/malignant.csv", root_path="/content/train_and_test_folder/train",
                              transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

Here is my discriminator and train function code:

class Discriminator(nn.Module):
    def __init__(self, NUM_GPU):
        super(Discriminator, self).__init__()
        self.NUM_GPU = NUM_GPU
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False), 
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

DISCRIM = Discriminator(NUM_GPU).to(DEVICE)
if (DEVICE.type == 'cuda') and (NUM_GPU > 1):
    DISCRIM = nn.DataParallel(DISCRIM, list(range(NUM_GPU)))
DISCRIM.apply(weights_initialization)
print(DISCRIM)

FIXED_NOISE = torch.randn(64, Z_SIZE, 1, 1, device=DEVICE)
GEN_OPTIMIZER = torch.optim.Adam(
    MODEL.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
DIS_OPTIMIZER = torch.optim.Adam(
    DISCRIM.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))


def train():
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    for epoch in range(NUM_EPOCHS):
        for i, data in enumerate(train_dataloader, 0):
            DISCRIM.zero_grad()
            rg = data[0].to(DEVICE)
            bs = rg.size(0)
            label = torch.full((bs,), REAL_LABEL,
                               dtype=torch.float, device=DEVICE) #REAL_LABEL is set to 1
            output = DISCRIM(rg).view(-1)
            dis_cost_real = LOSS_FUNCTION(output, label) # Right here is where this error is raised; my guess is that I'm using the wrong shape for output, but I'm not sure how (since the number referenced in the error is 14400)
            dis_cost_real.backward()
            dx = output.mean().item()

            noise = torch.randn(bs, Z_SIZE, 1, 1, device=DEVICE)
            fake = MODEL(noise)
            label.fill_(FAKE_LABEL) #FAKE_LABEL is set to 0
            output = DISCRIM(fake.detach()).view(-1)
            dis_cost_fake = LOSS_FUNCTION(output, label)
            dis_cost_fake.backward()
            dgz = output.mean().item()
            dis_cost = dis_cost_fake + dis_cost_real
            DIS_OPTIMIZER.step()

            MODEL.zero_grad()
            label.fill_(REAL_LABEL)
            output = DISCRIM(fake).view(-1)
            gen_cost = LOSS_FUNCTION(output, label)
            gen_cost.backward()
            dgz2 = output.mean().item()
            GEN_OPTIMIZER.step()

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, NUM_EPOCHS, i, len(train_dataloader),
                         dis_cost.item(), gen_cost.item(), dx, dgz, dgz2))

            G_losses.append(gen_cost.item())
            D_losses.append(dis_cost.item())

            if (iters % 500 == 0) or ((epoch == NUM_EPOCHS-1) and (i == len(train_dataloader)-1)):
                with torch.no_grad():
                    fake = MODEL(FIXED_NOISE).detach().to(DEVICE)
                img_list.append(torchvision.utils.make_grid(
                    fake, padding=2, normalize=True))

            iters += 1

train()

Hi kpullela!

14400 is 64 * 15**2, so I expect that your discriminator is returning batches
of “images” of shape [nBatch = 64, nChannels = 1, H = 15, W = 15].
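
One quick way to confirm this is to run a single real batch through the discriminator
and print the raw output shape before the .view(-1) (a minimal sketch, reusing your
existing train_dataloader, DISCRIM, and DEVICE):

rg, _ = next(iter(train_dataloader))   # one real batch of images
out = DISCRIM(rg.to(DEVICE))
print(out.shape)   # if this prints [64, 1, 15, 15], then .view(-1) flattens it to [14400]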

Your Discriminator is fully convolutional, so the height and width of its output
will depend on the height and width of its input. Looking through the numbers,
I expect that you are passing in batches of samples that have shapes equal to
[nBatch = 64, nChannels = 3, H = 512, W = 512].

If you know that your input images will always have this size (or you can ensure
it by resizing, padding, etc.) then you know what the shape of the output will be.

In any event, you want your Discriminator to output a single value for each
sample in the batch, so a shape of [64] for a batch size of 64.

Let’s assume that you do always pass in batches of images of shape
[nBatch, 3, 512, 512] (nBatch can vary, but the number of channels,
height, and width have to be fixed) and that Discriminator therefore outputs
(batches of) images of height and width both 15.

You should add a Flatten (1) layer to the end of your network to “flatten” all
but the nBatch dimension (giving a shape of [nBatch, 225]) followed by
a Linear (255, 1) (giving a shape of [nBatch, 1], after which you would
likely squeeze (-1) the trailing singleton dimension away). This will output a
shape of [nBatch], that is, a single predicted value for each sample in the
batch, that will match the shape of the target that you are passing to your
LOSS_FUNCTION.

(output = DISCRIM (rg) will suffice without the .view(-1).)

Last, use BCEWithLogitsLoss (instead of BCELoss) for better numerical
stability, and, correspondingly, get rid of the Sigmoid().

So, something like:

        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False), 
            nn.Flatten (1),
            nn.Linear (255, 1)
        )
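
For completeness, a minimal sketch of the matching forward() and loss-function
change (the squeeze (-1) removes the trailing singleton dimension, as described
above; this assumes you keep the rest of your training code as-is):

    def forward(self, input):
        # [nBatch, 1] -> [nBatch], matching the shape of your label tensor
        return self.main(input).squeeze(-1)

# and, correspondingly, in your training setup:
LOSS_FUNCTION = nn.BCEWithLogitsLoss()   # instead of nn.BCELoss(); no Sigmoid() in the model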

Good luck!

K. Frank

Hi KFrank!

First off, just wanted to say that I really appreciate the in-depth response, as I’ve definitely gained some new insight into my code from it.

So, I went through my code and made some edits. I changed my train_transform to resize my images to 512 * 512, and updated my discriminator with the Flatten(1) and linear layer instead of Sigmoid() (in addition to changing the loss function and removing the .view(-1) - thanks for the tips!).

However, it seems that Flatten(1) gives a shape of 64, 841 (which I’d assume to be the result of 29*29 that comes from the previous layer - but please correct me if I’m mistaken), and therefore throws an error since the 64, 841 and 255, 1 matrices cannot be multiplied.

Is my understanding correct so far? Or have I perhaps missed something else that led to this multiplication error? Apologies if this is a fairly straightforward question, and again, thanks for your guidance.

Hi kpullela!

That was my mistake – you do want your input images to have H = W = 299.
(I had miscalculated the output size of the last Conv2d layer and overlooked
the transforms.Resize((299, 299)) in your original post.)

Yes, you are correct. 512 * 512 input images lead to 29 * 29 images after the
final Conv2d, so you should stick with your original 299 * 299 input images to
get 15 * 15 after the final Conv2d, which will match the 255 of the Linear layer
after Flatten.

Best.

K. Frank

@kpullela10 Just to add to the excellent answer from @KFrank, you can dynamically resize the images as a layer by using nn.AdaptiveAvgPool2d either before the conv2d layers, after them, or somewhere in the middle. The way it works is that you specify the output size you want. For example:

...
self.NUM_GPU = NUM_GPU
self.ap = nn.AdaptiveAvgPool2d((299, 299))
self.main = nn.Sequential(
...
def forward(self, input):
    input = self.ap(input)
    return self.main(input)
...

Each location to place this layer has its advantages and disadvantages. Placing it at the beginning will make the model stronger on the macro picture feature representations/detection but weaker on learning textures.

At the end will make the model good on textures but weaker on macro detection.

In the middle gives you a better mix of both macro and texture detection.

Hi KFrank,

Once again, thanks for the clarification. It’s now working, thanks to your suggestions.

Just a quick follow-up, however, regarding my generator.
(for reference, my generator code):

class Generator(nn.Module):
    def __init__(self, NUM_GPU):
        super(Generator, self).__init__()
        self.NUM_GPU = NUM_GPU
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

(the place where I’m passing the generator’s results into the discriminator, inside of my train function):

            noise = torch.randn(64, 100, 1, 1, device=DEVICE)
  
            print(noise.shape) # ---> torch.Size([64, 100, 1, 1])
            fake = MODEL(noise)
            print(fake.shape) # ---> torch.Size([64, 3, 64, 64])
            label.fill_(FAKE_LABEL)
            output = DISCRIM(fake.detach())

I’ve assumed that I need to make similar changes since when I try to pass the fake data generated by the generator into the discriminator, I get a matrix multiplication error in the last linear of the discriminator: 64x1 against 225x1 (also, I presume you meant 225 instead of 255 above for 15^2?).

After looking at the shape of the fake data generated, its height and width are both 64, whereas I should be expecting H = W = 299 if I’ve understood this correctly.

I’ve been playing around with resizing fake to 299x299, but that doesn’t seem to be a great approach, as errors were raised since 64x3x64x64 (786,432 items) can’t be reshaped into 299x299 (89401 items).

My next step was to then change the layers of the generator, but this is where I got stuck, as I’m not sure how to change my layers to get 299.

I guess my question is, am I looking in the right direction? Or is the error actually a result of something else, and not my generator?

Again, really appreciate the help. Please let me know if I can clarify anything!

Hi J_Johnson,

Thanks for the tip! I think I’m currently running into some shape misalignments with my generator, so I’ll try this out, and see how it goes.

There are two approaches for getting the exact size out of your layers:

  1. Just use a print statement at the point in question - i.e. trial and error.
  2. Calculate for each layer using the docs. For example, for Conv2d it is:
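
H_out = floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

As a small sketch, the same calculation in code, traced through the discriminator's
conv layers for a 299 * 299 input (the helper function name is just for illustration):

def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    # output-size formula from the Conv2d docs
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

size = 299
for _ in range(4):                    # the four 4x4, stride-2, padding-1 convs
    size = conv2d_out(size, 4, 2, 1)  # 299 -> 149 -> 74 -> 37 -> 18
print(conv2d_out(size, 4, 1, 0))      # final 4x4, stride-1, padding-0 conv: 15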

Hi kpullela!

I wouldn’t say that the error is in your generator, per se. Rather, your generator
and discriminator don’t match, in that your generator produces images with
spatial size H = W = 64, while your discriminator expects H = W = 299.

Some perspective on what is going on:

Your discriminator has a number of Conv2d layers. Because they have stride
greater than 1, they are, in a sense, down-sampling your images, reducing H
and W (while encoding some of the spatial information into the increasingly
large channels dimension).

Conversely, the ConvTranspose2d layers are up-sampling the (initially 1 * 1)
“noise” images you pass in (moving channels information into the increasing
spatial extent).

The main driver in the change in image size is the stride parameter, but the
kernel_size and padding parameters also affect the spatial size of the output
of the convolutional layers.
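
As a concrete illustration (a small sketch using the output-size formula from the
ConvTranspose2d docs, with an illustrative helper name), here is how your current
generator arrives at 64 * 64:

def convT2d_out(size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    # output-size formula from the ConvTranspose2d docs
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

size = convT2d_out(1, 4, 1, 0)         # the 1 * 1 "noise" input -> 4
for _ in range(4):
    size = convT2d_out(size, 4, 2, 1)  # 4 -> 8 -> 16 -> 32 -> 64
print(size)                            # 64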

You need to chain together the appropriate up-sampling (generator) and
down-sampling (discriminator) layers so that the sizes your generator and
discriminator work with are consistent with one another.

For debugging and development purposes, I would get rid of the Sequentials
that wrap your layers and chain them together explicitly in your forward()
methods. This way you can add print() statements that track the image
sizes as they pass through the layers.

Here is a simple illustration of this:

>>> import torch
>>> print (torch.__version__)
1.13.1
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> class Model (torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.conv1 = torch.nn.Conv2d(3, 64, 4, 2, 1, bias=False)
...         self.relu1 = torch.nn.LeakyReLU(0.2, inplace=True)
...         self.conv2 = torch.nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False)
...         self.relu2 = torch.nn.LeakyReLU(0.2, inplace=True)
...         self.conv3 = torch.nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False)
...
...     def forward (self, x):
...         print (x.shape)
...         x = self.conv1 (x)
...         print (x.shape)
...         x = self.relu1 (x)
...         x = self.conv2 (x)
...         print (x.shape)
...         x = self.relu2 (x)
...         x = self.conv3 (x)
...         print (x.shape)
...         return x
...
>>> model = Model()
>>>
>>> input = torch.randn (1, 3, 299, 299)
>>>
>>> y = model (input)
torch.Size([1, 3, 299, 299])
torch.Size([1, 64, 149, 149])
torch.Size([1, 128, 74, 74])
torch.Size([1, 256, 37, 37])

I assume that you have a dataset of real images. If these are all of the same
size, I would use that size for the generator output size and discriminator input
size, unless that size is unwieldy. (There’s nothing special about 299 * 299.)
If the sizes of your real images vary, then using Resize() makes sense, but I
would resize to a spatial extent that is “typical” of your real images.

As Jay suggested, look at the documentation for Conv2d and ConvTranspose2d
to get a sense of how the interplay between the various parameters determines
the size of the output image. That way, you can design your layers in an informed
way to do what you want.

As an aside, remember that you will be training your generator to generate images
that look like whatever gets input to your discriminator as “real” images. So if you
use various transforms as part of your dataloader pipeline, you will be training your
generator to mimic the results of those transforms (rather than the raw images in
your untransformed dataset).

Best.

K. Frank

Got it. Thanks, will do.

Ah, ok. I’m going to then reduce the amount of transforming I do before passing images to the discriminator. Thanks!