Whats the meaning of RuntimeError: 1only batches of spatial targets supported

can somebody tell me what is wrong with my code?

all the images have the size of 256x256
images have 3 channels and mask has 1 channel
and 1 batch for each

I don’t understand what this error is
RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 1, 256, 256]

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import os, sys, random, time
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
import re
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#model
class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        self.encoder1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size = 3, stride = 1, padding = 1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        self.conv_mid = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),   
                                     nn.Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder4 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder3 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder1 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.upconv4 = nn.ConvTranspose2d(1024, 1024, kernel_size=(2, 2), stride=2)
        
        self.upconv3 = nn.ConvTranspose2d(512, 512, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv2 = nn.ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv1 = nn.ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
        
        self.conv1x1_out = nn.Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        out = self.encoder1(x)
        out = self.pool1(out)
        out = self.encoder2(out)
        out = self.pool2(out)
        out = self.encoder3(out)
        out = self.pool3(out)
        out = self.encoder4(out)
        out = self.pool4(out)
        out = self.conv_mid(out)
        out = self.upconv4(out)
        out = self.decoder4(out)
        out = self.upconv3(out)
        out = self.decoder3(out)
        out = self.upconv2(out)
        out = self.decoder2(out)
        out = self.upconv1(out)
        out = self.decoder1(out)
        out = self.conv1x1_out(out)
        return out
    
model = unet().to(device)
class BasicDataset(Dataset):
    def __init__(self, imgs_dir, masks_dir):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.mriids = next(os.walk(self.imgs_dir))[2]
        self.maskids = next(os.walk(self.masks_dir))[2]

    def __len__(self):
        return len(self.mriids)

    def __getitem__(self, i):
        
        def atoi(text):
            return int(text) if text.isdigit() else text

        def natural_keys(text):
            '''
            alist.sort(key=natural_keys) sorts in human order
            http://nedbatchelder.com/blog/200712/human_sorting.html
            (See Toothy's implementation in the comments)
            '''
            return [atoi(c) for c in re.split(r'(\d+)', text) ] 
        
        self.mriids = sorted(self.mriids, key = natural_keys)
        self.maskids = sorted(self.maskids, key = natural_keys)
        
        mriidx = self.mriids[i]
        maskidx = self.maskids[i]
        
        mask_file = glob(self.masks_dir + maskidx)
        img_file = glob(self.imgs_dir + mriidx)
               
        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])
        
        mask_np = np.array(mask)
        img_np = np.array(img)
        
        mask_np = np.expand_dims(mask_np, axis=-1)
        
        mask_np = np.transpose(mask_np, (2, 0, 1))
        img_np = np.transpose(img_np, (2, 0, 1))   
        
        #print(img_np.shape)
        #print(mask_np.shape)
        
        return {'image': torch.from_numpy(img_np), 'mask': torch.from_numpy(mask_np)}
epochs = 1000
batch_size = 1
lr = 0.00001
momentum = 0.99
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
loss_func = nn.CrossEntropyLoss().to(device)
gen = BasicDataset('/home/intern/Desktop/YH/Brain_MRI/BrainMRI_train/MRI/MRI/', '/home/intern/Desktop/YH/Brain_MRI/BrainMRI_train/mask/mask/')
train_loader = DataLoader(gen, batch_size=batch_size)
total_batch = len(gen)

model.train()
for epoch in range(epochs):
    
    for batch in train_loader:
        t0 = time.time()
        avg_cost = 0.0
        
        mri = batch['image']
        true_mask = batch['mask']
        
        mri = mri.type(torch.FloatTensor)
        true_mask = true_mask.type(torch.LongTensor)
        
        mri = mri.to(device)
        true_mask = true_mask.to(device)
        
        pred_mask = model(mri)
        loss = loss_func(pred_mask, true_mask)

        optimizer.zero_grad()
        loass.backward()
        optimizer.step()
        
        avg_cost += loss / total_batch
        t1 = time.time()
    print('[Epoch:{}], cost = {}, time = {}'.format(epoch+1, avg_cost, t1-t0))
print('training Finished!')

Output

{'mask': tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]]], dtype=torch.uint8), 'image': tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 1,  ..., 1, 0, 0],
          [0, 0, 0,  ..., 1, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [3, 4, 4,  ..., 6, 5, 2],
          [4, 4, 4,  ..., 5, 6, 4],
          [3, 2, 2,  ..., 2, 1, 2]],

         [[4, 4, 3,  ..., 0, 0, 0],
          [2, 1, 1,  ..., 0, 0, 0],
          [2, 3, 3,  ..., 0, 0, 0],
          ...,
          [2, 2, 1,  ..., 1, 1, 1],
          [1, 2, 2,  ..., 2, 1, 1],
          [1, 1, 1,  ..., 1, 2, 3]]]], dtype=torch.uint8)}

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-69-a858eeed6905> in <module>()
     22 
     23         pred_mask = model(mri)
---> 24         loss = loss_func(pred_mask, true_mask)
     25 
     26         optimizer.zero_grad()

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    914     def forward(self, input, target):
    915         return F.cross_entropy(input, target, weight=self.weight,
--> 916                                ignore_index=self.ignore_index, reduction=self.reduction)
    917 
    918 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2019     if size_average is not None or reduce is not None:
   2020         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2021     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2022 
   2023 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1838         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1839     elif dim == 4:
-> 1840         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1841     else:
   1842         # dim == 3 or dim > 4

RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 1, 256, 256]

For a multi-class segmentation nn.CrossEntropyLoss expects the target as [batch_size, height, width] containing the class indices is [0, nb_classes], while your target has an additional channel dimension.
Remove it via target = target.squeeze(1).