Loss.backward() mismatch

Hello there!

I am really new to PyTorch. While training the discriminator network for a GAN, I am running into a shape-mismatch error in the gradients during backpropagation:

RuntimeError: Function AddmmBackward returned an invalid gradient at index 1 - got [1, 16] but expected shape compatible with [1, 5520]

My inputs are images with shape (batch_size, channels, height, width) = [1, 3, 270, 387].
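
If it helps, I tried to trace where the 5520 in the error might come from. Each conv layer has kernel_size = 3, stride = 2 and no padding, so the output spatial size should be floor((in - 3) / 2) + 1. A quick sanity check (conv_out is just a helper I wrote for this, not part of my model):

def conv_out(size, kernel = 3, stride = 2):
    # output spatial size of a Conv2d with no padding
    return (size - kernel) // stride + 1

h, w = 270, 387
for _ in range(4):   # the four Conv2d layers in the model below
    h, w = conv_out(h), conv_out(w)

print(h, w)          # 15 23
print(16 * h * w)    # 5520 -- the shape the error expects

So after nn.Flatten() each sample seems to have 16 * 15 * 23 = 5520 features, which matches the [1, 5520] in the error, while the gradient my layer got is [1, 16].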

This is the model I defined for the discriminator:

import torch
import torch.nn as nn

class Discriminator(nn.Module):
  def __init__(self, im_chan = 3, hidden_dim = 32):
    super().__init__()

    self.model = nn.Sequential(

        nn.Conv2d(3, 32, kernel_size = 3, stride = 2),
        nn.LeakyReLU(0.2),
        nn.BatchNorm2d(32),

        nn.Conv2d(32, 64, kernel_size = 3, stride = 2),
        nn.LeakyReLU(0.2),
        nn.BatchNorm2d(64),

        nn.Conv2d(64, 32, kernel_size = 3, stride = 2),
        nn.LeakyReLU(0.2),

        nn.Conv2d(32, 16, kernel_size = 3, stride = 2),
        nn.LeakyReLU(0.02),

        nn.Flatten(),
        nn.Linear(16, 1),
        nn.Sigmoid()

    )
    self.loss_function = nn.BCELoss()

    self.optimiser = torch.optim.Adam(self.parameters(), lr = 0.0001)

    self.counter = 0

    self.progress = []

  def forward(self, inputs):

    return self.model(inputs)

  def train(self, inputs, targets):

    outputs = self.forward(inputs)

    print(outputs.shape)

    targets = targets.unsqueeze(0)

    print(targets.shape)

    loss = self.loss_function(outputs, targets)
    print('Loss calculated!')
    self.counter += 1

    if (self.counter % 10 == 0):
      self.progress.append(loss.item())

    if (self.counter % 1000 == 0):
      print('counter = ', self.counter)

    self.optimiser.zero_grad()
    loss.backward()
    self.optimiser.step()

When training the discriminator, the target is defined as a torch.cuda.FloatTensor([1.0]) for real images and torch.cuda.FloatTensor([0.0]) for the fake ones.
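
For context, a stripped-down version of one real-image training step looks roughly like this (real_images here is just a random stand-in for an actual batch):

D = Discriminator().cuda()
real_images = torch.randn(1, 3, 270, 387, device = 'cuda')   # stand-in for one real image batch
D.train(real_images, torch.cuda.FloatTensor([1.0]))          # real image -> target 1.0
# a generated image would instead use torch.cuda.FloatTensor([0.0])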

I am stuck here. Any help would be much appreciated!

Hi,

Can you share a full script that we can run to reproduce the issue?