RuntimeError: size mismatch (got input: [6422528], target: [802816])

I’m trying to adapt a binary segmentation U-Net model so that I can train a multi-class U-Net on the German Asphalt Pavement Distress (GAPs) dataset.

Traceback (most recent call last):
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 263, in <module>
    train(train_loader, model, criterion, optimizer, validate, args)
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 124, in train
    loss = criterion(masks_probs_flat, true_masks_flat)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 1165, in forward
    label_smoothing=self.label_smoothing)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2996, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: size mismatch (got input: [6422528], target: [802816])

The code files and the dataset are available through the following link:
https://drive.google.com/drive/folders/14NQdtMXokIixBJ5XizexVECn23Jh9aTM?usp=sharing

The following link is for my previous Stack Overflow question (before I changed the criterion to use nn.CrossEntropyLoss). I’m totally new to PyTorch, and I look forward to receiving your valuable advice.

The following is the code in “train_unet_GAPs.py”:

import torch
from torch import nn
from unet.unet_transfer import UNet16, UNetResNet
from pathlib import Path
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn.functional as F
from torch.autograd import Variable
import shutil
from data_loader import ImgDataSet
import os
import argparse
import tqdm
import numpy as np
import scipy.ndimage as ndimage

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def create_model(device, type ='vgg16'):
    if type == 'vgg16':
        print('create vgg16 model')
        model = UNet16(pretrained=True)
    elif type == 'resnet101':
        encoder_depth = 101
        num_classes = 8
        print('create resnet101 model')
        model = UNetResNet(encoder_depth=encoder_depth, num_classes=num_classes, pretrained=True)
    elif type == 'resnet34':
        encoder_depth = 34
        num_classes = 8
        print('create resnet34 model')
        model = UNetResNet(encoder_depth=encoder_depth, num_classes=num_classes, pretrained=True)
    else:
        assert False
    model.eval()
    return model.to(device)

def adjust_learning_rate(optimizer, epoch, lr):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def find_latest_model_path(dir):
    model_paths = []
    epochs = []
    for path in Path(dir).glob('*.pt'):
        if 'epoch' not in path.stem:
            continue
        model_paths.append(path)
        parts = path.stem.split('_')
        epoch = int(parts[-1])
        epochs.append(epoch)

    if len(epochs) > 0:
        epochs = np.array(epochs)
        max_idx = np.argmax(epochs)
        return model_paths[max_idx]
    else:
        return None

def train(train_loader, model, criterion, optimizer, validation, args):

    latest_model_path = find_latest_model_path(args.model_dir)

    best_model_path = os.path.join(*[args.model_dir, 'model_best.pt'])

    if latest_model_path is not None:
        state = torch.load(latest_model_path)
        epoch = state['epoch']
        model.load_state_dict(state['model'])

        #if the latest model path exists, best_model_path should exist as well
        assert Path(best_model_path).exists() == True, f'best model path {best_model_path} does not exist'
        #load the min loss so far
        best_state = torch.load(best_model_path)
        min_val_los = best_state['valid_loss']

        print(f'Restored model at epoch {epoch}. Min validation loss so far is : {min_val_los}')
        epoch += 1
        print(f'Started training model from epoch {epoch}')
    else:
        print('Started training model from epoch 0')
        epoch = 0
        min_val_los = 9999

    valid_losses = []
    for epoch in range(epoch, args.n_epoch + 1):

        adjust_learning_rate(optimizer, epoch, args.lr)

        tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}')

        losses = AverageMeter()

        model.train()
        for i, (input, target) in enumerate(train_loader):
            input_var  = Variable(input).cuda()
            target_var = Variable(target).cuda()

            masks_pred = model(input_var)

            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat  = target_var.view(-1)
            print(masks_probs_flat.shape, true_masks_flat.shape)

            loss = criterion(masks_probs_flat, true_masks_flat)
            losses.update(loss)
            tq.set_postfix(loss='{:.5f}'.format(losses.avg))
            tq.update(args.batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        valid_metrics = validation(model, valid_loader, criterion)
        valid_loss = valid_metrics['valid_loss']
        valid_losses.append(valid_loss)
        print(f'\tvalid_loss = {valid_loss:.5f}')
        tq.close()

        #save the model of the current epoch
        epoch_model_path = os.path.join(*[args.model_dir, f'model_epoch_{epoch}.pt'])
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'valid_loss': valid_loss,
            'train_loss': losses.avg
        }, epoch_model_path)

        if valid_loss < min_val_los:
            min_val_los = valid_loss

            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': valid_loss,
                'train_loss': losses.avg
            }, best_model_path)

def validate(model, val_loader, criterion):
    losses = AverageMeter()
    model.eval()
    with torch.no_grad():

        for i, (input, target) in enumerate(val_loader):
            input_var = Variable(input).cuda()
            target_var = Variable(target).long().cuda()
            

            output = model(input_var)
            loss = criterion(output, target_var)

            losses.update(loss.item(), input_var.size(0))

    return {'valid_loss': losses.avg}

def save_check_point(state, is_best, file_name = 'checkpoint.pth.tar'):
    torch.save(state, file_name)
    if is_best:
        shutil.copy(file_name, 'model_best.pth.tar')

def calc_crack_pixel_weight(mask_dir):
    avg_w = 0.0
    n_files = 0
    for path in Path(mask_dir).glob('*.*'):
        n_files += 1
        m = ndimage.imread(path)
        ncrack = np.sum((m > 0)[:])
        w = float(ncrack)/(m.shape[0]*m.shape[1])
        avg_w = avg_w + (1-w)

    avg_w /= float(n_files)

    return avg_w / (1.0 - avg_w)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('-n_epoch', default=10, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-lr', default=0.001, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('-momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('-print_freq', default=20, type=int, metavar='N', help='print frequency (default: 20)')
    parser.add_argument('-weight_decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('-batch_size',  default=4, type=int,  help='batch size (default: 4)')
    #parser.add_argument('-batch_size',  default=2, type=int,  help='weight decay (default: 1e-4)')
    #parser.add_argument('-num_workers', default=4, type=int, help='output dataset directory')
    parser.add_argument('-num_workers', default=2, type=int, help='number of data loader workers (default: 2)')

    parser.add_argument('-data_dir',type=str, help='input dataset directory')
    parser.add_argument('-model_dir', type=str, help='output model directory')
    parser.add_argument('-model_type', type=str, required=False, default='resnet101', choices=['vgg16', 'resnet101', 'resnet34'])

    args = parser.parse_args()
    os.makedirs(args.model_dir, exist_ok=True)

    DIR_IMG  = os.path.join(args.data_dir, 'images')
    DIR_MASK = os.path.join(args.data_dir, 'masks')

    img_names  = [path.name for path in Path(DIR_IMG).glob('*.jpg')]
    mask_names = [path.name for path in Path(DIR_MASK).glob('*.png')]

    print(f'total images = {len(img_names)}')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = create_model(device, args.model_type)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    #crack_weight = 0.4*calc_crack_pixel_weight(DIR_MASK)
    #print(f'positive weight: {crack_weight}')
    #criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([crack_weight]).to('cuda'))
    #criterion = nn.BCEWithLogitsLoss().to('cuda')
    criterion = nn.CrossEntropyLoss().to('cuda')

    #channel_means = [0.485, 0.456, 0.406]
    #channel_stds  = [0.229, 0.224, 0.225]
    channel_means = [0.5]
    channel_stds  = [0.5]
    train_tfms = transforms.Compose([transforms.Resize((448,448)), transforms.ToTensor(),
                                     transforms.Normalize(channel_means, channel_stds)])

    val_tfms = transforms.Compose([transforms.Resize((448,448)), transforms.ToTensor(),
                                   transforms.Normalize(channel_means, channel_stds)])

    mask_tfms = transforms.Compose([transforms.Resize((448,448)), transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])
    '''
    mask_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3,1,1)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    '''
    dataset = ImgDataSet(img_dir=DIR_IMG, img_fnames=img_names, img_transform=train_tfms, mask_dir=DIR_MASK, mask_fnames=mask_names, mask_transform=mask_tfms)
    train_size = int(0.85*len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=False, pin_memory=torch.cuda.is_available(), num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset, args.batch_size, shuffle=False, pin_memory=torch.cuda.is_available(), num_workers=args.num_workers)

    #model.cuda()
    model.to(torch.device("cuda:0"))

    train(train_loader, model, criterion, optimizer, validate, args)

The following is the code in “data_loader.py”:

import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
import random
from PIL import Image
import matplotlib.pyplot as plt

class ImgDataSet(Dataset):
    def __init__(self, img_dir, img_fnames, img_transform, mask_dir, mask_fnames, mask_transform):
        self.img_dir = img_dir
        self.img_fnames = img_fnames
        self.img_transform = img_transform

        self.mask_dir = mask_dir
        self.mask_fnames = mask_fnames
        self.mask_transform = mask_transform

        self.seed = np.random.randint(2147483647)

    def __getitem__(self, i):
        fname = self.img_fnames[i]
        fpath = os.path.join(self.img_dir, fname)
        img = Image.open(fpath)
        if self.img_transform is not None:
            random.seed(self.seed)
            img = self.img_transform(img)
            #print('image shape', img.shape)

        mname = self.mask_fnames[i]
        mpath = os.path.join(self.mask_dir, mname)
        mask = Image.open(mpath)
        #print('khanh1', np.min(test[:]), np.max(test[:]))
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
            #print('mask shape', mask.shape)
            #print('khanh2', np.min(test[:]), np.max(test[:]))

        return img, mask #torch.from_numpy(np.array(mask, dtype=np.int64))

    def __len__(self):
        return len(self.img_fnames)


class ImgDataSetJoint(Dataset):
    def __init__(self, img_dir, img_fnames, joint_transform, mask_dir, mask_fnames, img_transform = None, mask_transform = None):
        self.joint_transform = joint_transform

        self.img_dir = img_dir
        self.img_fnames = img_fnames
        self.img_transform = img_transform

        self.mask_dir = mask_dir
        self.mask_fnames = mask_fnames
        self.mask_transform = mask_transform

        self.seed = np.random.randint(2147483647)

    def __getitem__(self, i):
        fname = self.img_fnames[i]
        fpath = os.path.join(self.img_dir, fname)
        img = Image.open(fpath)

        mname = self.mask_fnames[i]
        mpath = os.path.join(self.mask_dir, mname)
        mask = Image.open(mpath)

        if self.joint_transform is not None:
            img, mask = self.joint_transform([img, mask])

        #debug
        # img = np.asarray(img)
        # mask = np.asarray(mask)
        # plt.subplot(121)
        # plt.imshow(img)
        # plt.subplot(122)
        # plt.imshow(img)
        # plt.imshow(mask, alpha=0.4)
        # plt.show()

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return img, mask #torch.from_numpy(np.array(mask, dtype=np.int64))

    def __len__(self):
        return len(self.img_fnames)

Hi Mohamed!

The short story is that you are using CrossEntropyLoss incorrectly.

Making some guesses about your use case, I would expect both
input_var and target_var to have shape
[nBatch = 4, nChannels = 1, 448, 448]. (You might not have
an nChannels dimension, and the values of nBatch and nChannels
aren’t really relevant to your issue.)

However, your multi-class model will add an nClass = 8 dimension to
its predictions.

Therefore I would expect masks_pred to have shape
[nBatch = 4, nChannels = 1, 448, 448, nClass = 8].

At this point masks_pred and target_var are appropriate to pass into
CrossEntropyLoss (with the proviso that target_var should consist of
long class labels that run from 0 to nClass - 1 = 7).

Your mistake is that you then .view() masks_probs_flat and
true_masks_flat as one-dimensional tensors. (When you print
out their shapes you presumably get [6422528] and [802816],
respectively.) This just isn’t right for CrossEntropyLoss, hence
the error.
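
(For what it's worth, the sizes in the error message are consistent with this
picture. A quick sanity check, just arithmetic, assuming nBatch = 4, nClass = 8,
a single-channel target, and 448x448 images:)

>>> 4 * 8 * 448 * 448    # number of elements in the flattened masks_pred
6422528
>>> 4 * 1 * 448 * 448    # number of elements in the flattened target_var
802816

The factor-of-eight mismatch is exactly the nClass dimension that the
flattening throws away.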

Best.

K. Frank

Hi K. Frank,

Thanks for your reply. Could you please advise what I should revise in the code to get it running? If I comment out the .view() lines and use loss = criterion(masks_pred, target_var), I get the following error:

Traceback (most recent call last):
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 265, in <module>
    train(train_loader, model, criterion, optimizer, validate, args)
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 125, in train
    loss = criterion(masks_pred, target_var)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 1165, in forward
    label_smoothing=self.label_smoothing)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2996, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [4, 1, 448, 448]

Hi Mohamed!

That’s my fault. I was forgetting that CrossEntropyLoss expects its input
(your masks_pred) to have its second dimension be its nClass dimension.

So if target_var has shape [4, 1, 448, 448] (and nClass = 8), you
would want masks_pred to have shape [4, 8, 1, 448, 448]. If your
model produces a masks_pred with shape [4, 1, 448, 448, 8] (i.e.,
with the nClass dimension last), you could swap the dimensions around
by calling your loss criterion as:

loss = criterion (masks_pred.permute (0, 4, 1, 2, 3), target_var)

(If this doesn’t work, first try squeeze()ing away the singleton channel
dimension, and if it still doesn’t work, print out the shapes of masks_pred
and target_var and tell us what version of pytorch you are using so we can
be sure what we’re looking at.)

Best.

K. Frank

Hi K. Frank,
The shape of masks_pred is [4, 8, 448, 448], not [4, 8, 1, 448, 448].
The shape of target_var is [4, 1, 448, 448].

The torch version is 1.11.0+cu113.

Best regards,
Mohamed

Hi Mohamed!

That clears things up.

masks_pred correctly has its nClass dimension as its second
dimension. But target_var has that extra singleton dimension that
doesn’t line up with anything in masks_pred. (In the most common
integer categorical class-labels use case, masks_pred should have
one more dimension – the nClass dimension – than target_var.)

Just squeeze() that singleton dimension away. Consider:

>>> import torch
>>> torch.__version__
'1.11.0'
>>> _ = torch.manual_seed (2022)
>>> masks_pred = torch.randn (4, 8, 448, 448)
>>> target_var = torch.randint (8, (4, 1, 448,  448))
>>> torch.nn.CrossEntropyLoss() (masks_pred, target_var.squeeze())
tensor(2.4896)

Best.

K. Frank

Thanks K. Frank.
We are getting closer. The dimension errors are gone after I revised the loss call as you suggested: loss = criterion(masks_pred, target_var.squeeze())
Now I get the following error:

Traceback (most recent call last):
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 265, in <module>
    train(train_loader, model, criterion, optimizer, validate, args)
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 125, in train
    loss = criterion(masks_pred, target_var.squeeze())
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 1165, in forward
    label_smoothing=self.label_smoothing)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2996, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: expected scalar type Long but found Float

To try to correct the above error, I revised target_var as follows:

target_var = Variable(target).type(torch.LongTensor).cuda()

However, I got the following error:

../aten/src/ATen/native/cuda/NLLLoss2d.cu:93: nll_loss2d_forward_kernel: block: [3,0,0], thread: [977,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:93: nll_loss2d_forward_kernel: block: [3,0,0], thread: [978,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:93: nll_loss2d_forward_kernel: block: [3,0,0], thread: [979,0,0] Assertion `t >= 0 && t < n_classes` failed.
[... the same assertion is repeated for many more threads ...]
Traceback (most recent call last):
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 265, in <module>
    train(train_loader, model, criterion, optimizer, validate, args)
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 127, in train
    tq.set_postfix(loss='{:.5f}'.format(losses.avg))
  File "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py", line 627, in __format__
    return self.item().__format__(format_spec)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

I look forward to getting your advice.

When I added print(mask.min(), mask.max()), I got the following results:

tensor(-1.) tensor(-0.9529)
tensor(-0.9922) tensor(-0.9529)
tensor(-1.) tensor(-0.9529)
tensor(-1.) tensor(-0.9451)
tensor(-0.9922) tensor(-0.9451)
tensor(-0.9922) tensor(-0.9451)
tensor(-1.) tensor(-0.9529)

I can’t understand why that is. The masks are grayscale images with the following coding:

# Mask images are stored as grey value images.
# Coding of the grey values:
#    0 = VOID
#    1 = intact road,
#    2 = applied patch,
#    3 = pothole,
#    4 = inlaid patch,
#    5 = open joint,
#    6 = crack
#    7 = street inventory

Hi Mohamed!

First, some advice:

You can’t just modify one bit of code in order to apply it to some rather
different use case and hope it works. You will almost always have to
modify other bits of your code to make them consistent with your
initial modification.

You are learning, detail by detail, that CrossEntropyLoss takes a
target that has a different form than the one used by whatever loss
criterion you were using prior to your modification.

Quoting from my earlier post:

At this point masks_pred and target_var are appropriate to pass into
CrossEntropyLoss (with the proviso that target_var should consist of
long class labels that run from 0 to nClass - 1 = 7).

Your errors are: first, target has to be of type long; and second, its values
have to range from 0 to 7.

Please take a look at the documentation for CrossEntropyLoss to see how
it is designed to be used, and verify that you are using it correctly.

(As an aside, Variable has been deprecated for quite some time
now and doesn’t do anything. Just use tensors, possibly setting
requires_grad = True, but only when appropriate.)
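
Putting those pieces together, here is a minimal sketch of what the body of
your training loop could look like (just a sketch, using the variable names
from your posted train() loop, and assuming that your masks really do carry
integer labels 0 through 7):

input_var  = input.cuda()                      # plain tensor, no Variable needed
target_var = target.squeeze(1).long().cuda()   # drop the singleton channel dim, cast to long

# sanity check while debugging; labels must be valid class indices
assert target_var.min() >= 0 and target_var.max() <= 7

masks_pred = model(input_var)                  # shape [nBatch, nClass, 448, 448]
loss = criterion(masks_pred, target_var)       # target shape [nBatch, 448, 448]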

Best.

K. Frank

Hi Mohamed!

You say that you have masks on disk that have pixel values that run from
0 through 7. This sounds just right for your use case. (You might double
check this, though.)

But when you print out values for mask in memory, you get floating-point
values near -1.0 – a problem for your use case.

You need to step through the various layers of processing that read your
mask from disk and turn it into the tensor mask. Something is changing
those values somewhere along the line.
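
For what it's worth, one way to build a mask pipeline that keeps the integer
labels intact (a sketch, not tested against your data, and assuming a
reasonably recent torchvision) is to resize with nearest-neighbor interpolation
and to avoid ToTensor(), which rescales uint8 images to [0, 1]:

from torchvision import transforms
from torchvision.transforms import InterpolationMode

mask_tfms = transforms.Compose([
    transforms.Resize((448, 448), interpolation=InterpolationMode.NEAREST),  # no blending of neighboring labels
    transforms.PILToTensor(),   # keeps the raw integer values
])

# later, in the training loop (where the batched mask is called target):
# target_var = target.squeeze(1).long().cuda()   # class indices 0..7 for CrossEntropyLoss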

Best.

K. Frank

Hi K. Frank,

When I changed the line loss = criterion(masks_probs_flat, true_masks_flat) to loss = criterion(masks_pred, target_var.squeeze()), the training worked.
However, I think there is still something wrong.
The results of print(mask.min(), mask.max()) are now positive floats.

I think what changes the values from 0 through 7 into floats is the following set of transforms (I suspect the resizing is what changes the mask values):

train_tfms = transforms.Compose([transforms.Resize((448,448)), transforms.ToTensor()])
val_tfms = transforms.Compose([transforms.Resize((448,448)), transforms.ToTensor()])
mask_tfms = transforms.Compose([transforms.Resize((448,448)), transforms.ToTensor()])

I had to add the resizing because I understood from a previous discussion that ResNet only accepts square images as input.

I appreciate all your help so far. I look forward to receiving further advice from you.