Target size (torch.Size([4, 512, 512])) must be the same as input size (torch.Size([4, 1, 512, 512]))


ValueError Traceback (most recent call last)
/tmp/ipykernel_24/1691612924.py in
1 model_trainer = Trainer(model)
----> 2 model_trainer.start()

/tmp/ipykernel_24/785556323.py in start(self)
81 def start(self):
82 for epoch in range(self.num_epochs):
---> 83 self.iterate(epoch, "train")
84 state = {
85 "epoch": epoch,

/tmp/ipykernel_24/785556323.py in iterate(self, epoch, phase)
60 for itr, batch in enumerate(dataloader):
61 images, targets = batch
---> 62 loss, outputs = self.forward(images, targets)
63 loss = loss / self.accumulation_steps
64 if phase == "train":

/tmp/ipykernel_24/785556323.py in forward(self, images, targets)
44 masks = targets.to(self.device)
45 outputs = self.net(images)
—> 46 loss = self.criterion(outputs, masks)
47 return loss, outputs
48

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_24/3657000512.py in forward(self, input, target)
33
34 def forward(self, input, target):
—> 35 loss = self.alpha*self.focal(input, target) - torch.log(dice_loss(input, target))
36 return loss.mean()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_24/3657000512.py in forward(self, input, target)
17 if not (target.size() == input.size()):
18 raise ValueError("Target size ({}) must be the same as input size ({})"
—> 19 .format(target.size(), input.size()))
20 max_val = (-input).clamp(min=0)
21 loss = input - input * target + max_val + \

ValueError: Target size (torch.Size([4, 512, 512])) must be the same as input size (torch.Size([4, 1, 512, 512]))

The error message describes the problem directly: do you see the extra 1 in your input size? The model's output has shape `[4, 1, 512, 512]` (batch, channel, height, width), while the target masks have shape `[4, 512, 512]` — the channel dimension is missing. You can fix this in either direction: remove the channel dimension from the output with `input = input.squeeze(1)` (prefer `squeeze(1)` over a bare `squeeze()`, which would also drop a batch dimension if the batch size happened to be 1), or add a channel dimension to the target with `target = target.unsqueeze(1)`. Either way the two tensors will have matching shapes and the loss computation will work.