Issue with loss function when training

Hi all,

I have been designing a hyperspectral image classifier with 150 colour bands. The network breaks at the loss function with this error message:

only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [4]

What I can surmise is that the target of size 4 comes from my dataloader batch size being 4. But I can't understand what I have done wrong to cause this issue. Shouldn't it be this way when you set a batch size?

My train function looks like this:

def train(self, net, 
              optimizer, 
              loss, 
              data_loader, 
              epoch = x["HYPERPARAMETERS"]["epoch"], 
              scheduler=None,
              display_iter=100, display=None,
              val_loader=None):
        """
        Training loop to optimize a network for several epochs and a specified loss

        Args:
            net: a PyTorch model
            optimizer: a PyTorch optimizer
            data_loader: a PyTorch dataset loader
            epoch: int specifying the number of training epochs
            criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
            device (optional): torch device to use (defaults to CPU)
            display_iter (optional): number of iterations before refreshing the
            display (False/None to switch off).
            scheduler (optional): PyTorch scheduler
            val_loader (optional): validation dataset
            supervision (optional): 'full' or 'semi'
        """

        if criterion is None:
            raise Exception("Missing criterion. You must specify a loss function.")

        net.to(self.device)

        save_epoch = epoch // 20 if epoch > 20 else 1


        losses = np.zeros(1000000)
        mean_losses = np.zeros(100000000)
        iter_ = 1
        loss_win, val_win = None, None
        val_accuracies = []

        for e in tqdm(range(1, epoch + 1), desc="Training the network"):
            # Set the network to training mode
            net.train()
            avg_loss = 0.

            # Run the training loop for one epoch
            for batch_idx, (data, target) in enumerate(data_loader):
                
                # Reshape from NHWC to NCHW for the convolutional layers
                data = data.permute(0, 3, 1, 2)

                # Load the data onto the GPU if required
                data, target = data.to(self.device), target.to(self.device)
                
                optimizer.zero_grad()
                output = net(data)
                loss = loss(output, target)
                loss.backward()
                optimizer.step()

                avg_loss += loss.item()
                losses[iter_] = loss.item()
                mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])

                if display_iter and iter_ % display_iter == 0:
                    string = 'Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                    string = string.format(
                        e, epoch, batch_idx *
                        len(data), len(data) * len(data_loader),
                        100. * batch_idx / len(data_loader), mean_losses[iter_])
                    update = None if loss_win is None else 'append'
                    loss_win = display.line(
                        X=np.arange(iter_ - display_iter, iter_),
                        Y=mean_losses[iter_ - display_iter:iter_],
                        win=loss_win,
                        update=update,
                        opts={'title': "Training loss",
                              'xlabel': "Iterations",
                              'ylabel': "Loss"
                             }
                    )
                    tqdm.write(string)

                    if len(val_accuracies) > 0:
                        val_win = display.line(Y=np.array(val_accuracies),
                                               X=np.arange(len(val_accuracies)),
                                               win=val_win,
                                               opts={'title': "Validation accuracy",
                                                     'xlabel': "Epochs",
                                                     'ylabel': "Accuracy"
                                                    })
                iter_ += 1
                del(data, target, loss, output)




Which line of code is throwing this error?
Could you print the shapes of all tensors before calling this particular op?

Also, loss = loss(output, target) would mask the loss function (also called loss) with its result, so I would recommend using different names, e.g. loss = criterion(output, target).
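
For instance, a minimal self-contained sketch of that renaming (with dummy tensors rather than your actual model) would be:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()               # keep the loss *function* under its own name

output = torch.randn(4, 4, requires_grad=True)  # dummy logits, [batch_size, nb_classes]
target = torch.randint(0, 4, (4,))              # dummy class indices, [batch_size]

loss = criterion(output, target)                # the scalar result no longer shadows the function
loss.backward()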

I have reverted my code back to loss = criterion(output, target), where criterion is

criterion = nn.CrossEntropyLoss()

I printed the shape of the input data, the target (label) data, and the output shape from output = net(data).

Below is the error trace.

data shape = torch.Size([4, 150, 64, 64])
target shape = torch.Size([4])
output shape = torch.Size([4, 4, 64, 64])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-34-48f9f0494335> in <module>
      7               display_iter=100,
      8                 display=None,
----> 9               val_loader=None)

<ipython-input-32-a961c92dfd2f> in train(self, net, optimizer, criterion, data_loader, epoch, scheduler, display_iter, display, val_loader)
     67                     output = net(data)
     68                     print(f"output shape = {output.shape}")
---> 69                     loss = criterion(output, target)
     70                     print(f"loss = {loss}")
     71                 elif self.supervision == 'semi':

~\AppData\Local\Continuum\anaconda3\envs\Hyperspectral_testbed\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

~\AppData\Local\Continuum\anaconda3\envs\Hyperspectral_testbed\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
    902     def forward(self, input, target):
    903         return F.cross_entropy(input, target, weight=self.weight,
--> 904                                ignore_index=self.ignore_index, reduction=self.reduction)
    905 
    906 

~\AppData\Local\Continuum\anaconda3\envs\Hyperspectral_testbed\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   1968     if size_average is not None or reduce is not None:
   1969         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1970     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   1971 
   1972 

~\AppData\Local\Continuum\anaconda3\envs\Hyperspectral_testbed\lib\site-packages\torch\nn\functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1791     elif dim == 4:
-> 1792         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1793     else:
   1794         # dim == 3 or dim > 4

RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [4]

Based on your output shape, it looks like you are dealing with something like a segmentation use case.
If that's the case, your target should also provide pixel-wise class indices, i.e. have the shape [batch_size, height, width].
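
For example, a minimal sketch assuming 4 classes and your 64x64 patches: nn.CrossEntropyLoss accepts a [N, C, H, W] output together with a [N, H, W] target of class indices.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

output = torch.randn(4, 4, 64, 64)           # [batch_size, nb_classes, H, W], as in your print-out
target = torch.randint(0, 4, (4, 64, 64))    # pixel-wise class indices, [batch_size, H, W]

loss = criterion(output, target)             # no error: the spatial target matches the output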

If you would like to perform vanilla multi-class classification instead, your output should have the shape [batch_size, nb_classes].
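
A minimal sketch of that case, assuming one label per patch (as your target of size [4] suggests) and a hypothetical global-pooling step to collapse the spatial dimensions:

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()

features = torch.randn(4, 4, 64, 64)          # shaped like your current network output
pooled = F.adaptive_avg_pool2d(features, 1)   # collapse the spatial dimensions -> [4, 4, 1, 1]
logits = pooled.flatten(1)                    # -> [batch_size, nb_classes] = [4, 4]

target = torch.randint(0, 4, (4,))            # one class index per patch, [batch_size]
loss = criterion(logits, target)              # matches a target of size [4]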