ValueError: Target and input must have the same number of elements. target nelement (524288) != input nelement (209952)

I am getting this error from the code below during training. The error seems to come from the model, since the shapes are as follows:
images.shape = torch.Size([2, 1, 512, 512])

masks.shape = torch.Size([2, 1, 512, 512])

outputs.shape = torch.Size([2, 1, 324, 324])
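
The mismatch itself is easy to reproduce with an element-wise loss. My criterion isn't shown above, but assuming something like nn.BCELoss, this minimal sketch produces the same error:

import torch
import torch.nn as nn

criterion = nn.BCELoss()               # assumption: an element-wise loss like my criterion
outputs = torch.rand(2, 1, 324, 324)   # same shape as the model output above
masks = torch.rand(2, 1, 512, 512)     # same shape as the target masks

# Older PyTorch raises: ValueError: Target and input must have the same number of elements.
# (newer versions raise a similar size-mismatch ValueError)
loss = criterion(outputs, masks)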

from torch.autograd import Variable


def train_model(model, data_train, criterion, optimizer):
    """Train the model for one epoch.

    Args:
        model: the model to be trained
        data_train (DataLoader): training dataset
        criterion: loss function
        optimizer: optimizer used to update the weights
    """
    model.train()
    for batch, (images, masks) in enumerate(data_train):
        images = Variable(images.float().cuda())
        print(images.shape)
        masks = Variable(masks.float().cuda()) 
        print(masks.shape)
        outputs = model(images)
        print(outputs.shape)
        #print(masks.shape, outputs.shape)
        loss = criterion(outputs, masks).float()
        optimizer.zero_grad()
        loss.backward()
        # Update weights
        optimizer.step()
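
For context, this is roughly how the loop is driven; train_dataset, nn.BCELoss and Adam below are stand-ins for my actual dataset, loss and optimizer:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# train_dataset is a placeholder for my actual Dataset of (image, mask) pairs
data_train = DataLoader(train_dataset, batch_size=2, shuffle=True)

model = SegNet().cuda()
criterion = nn.BCELoss()                                   # stand-in for my actual loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # stand-in for my actual optimizer

for epoch in range(10):
    train_model(model, data_train, criterion, optimizer)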

And here is my SegNet model. What is wrong in the model that makes the outputs come out with shape 324?


import torch
import torch.nn as nn


class SegNet(nn.Module):
    def __init__(self):
        super(SegNet, self).__init__()

        self.encoder_1 = nn.Sequential(
            nn.Conv2d(3, 64, 7, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
        )  # first group

        self.encoder_2 = nn.Sequential(
            nn.Conv2d(64, 64, 7, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
        )  # second group

        self.encoder_3 = nn.Sequential(
            nn.Conv2d(64, 64, 7, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
        )  # third group

        self.encoder_4 = nn.Sequential(
            nn.Conv2d(64, 64, 7, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
        )  # fourth group

        self.unpool_1 = nn.MaxUnpool2d(2, stride=2)  # get masks
        self.unpool_2 = nn.MaxUnpool2d(2, stride=2)
        self.unpool_3 = nn.MaxUnpool2d(2, stride=2)
        self.unpool_4 = nn.MaxUnpool2d(2, stride=2)

        self.decoder_1 = nn.Sequential(
            nn.Conv2d(64, 64, 7, padding=3),
            nn.BatchNorm2d(64)
        )  # first group

        self.decoder_2 = nn.Sequential(
            nn.Conv2d(64, 64, 7, padding=3),
            nn.BatchNorm2d(64)
        )  # second group

        self.decoder_3 = nn.Sequential(
            nn.Conv2d(64, 64, 7, padding=3),
            nn.BatchNorm2d(64)
        )  # third group

        self.decoder_4 = nn.Sequential(
            nn.Conv2d(64, 1, 7, padding=3),
            nn.BatchNorm2d(1)
        )  # fourth group
   
    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        size_1 = x.size()
        x, indices_1 = self.encoder_1(x)

        size_2 = x.size()
        x, indices_2 = self.encoder_2(x)
        size_3 = x.size()
        x, indices_3 = self.encoder_3(x)

        size_4 = x.size()
        x, indices_4 = self.encoder_4(x)

        x = self.unpool_1(x, indices_4, output_size=size_4)
        x = self.decoder_1(x)

        x = self.unpool_2(x, indices_3, output_size=size_3)
        x = self.decoder_2(x)

        x = self.unpool_3(x, indices_2, output_size=size_2)
        x = self.decoder_3(x)

        x = self.unpool_4(x, indices_1, output_size=size_1)
        x = self.decoder_4(x)

        x = torch.sigmoid(x)

        return x

Using your model definition, I get an output of [batch_size, 1, 512, 512]:

model = SegNet()
x = torch.randn(1, 3, 512, 512)
output = model(x)
print(output.shape)
> torch.Size([1, 1, 512, 512])
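
If the model you are actually training outputs 324x324, its definition presumably differs from what you posted. As a quick way to find where the spatial size first changes, you could print the output shape of each stage with forward hooks (a sketch; the stage names match the modules above):

import torch

model = SegNet()

def make_hook(name):
    def hook(module, inputs, output):
        # Encoder stages return (tensor, pooling indices); keep only the tensor
        out = output[0] if isinstance(output, tuple) else output
        print(name, tuple(out.shape))
    return hook

# Attach a hook to every top-level stage (encoder_*, unpool_*, decoder_*)
for name, module in model.named_children():
    module.register_forward_hook(make_hook(name))

with torch.no_grad():
    _ = model(torch.randn(1, 3, 512, 512))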