Loss decreasing but predictions are empty

Hello everyone,
I am currently implementing a UNet3D model for a Brain Tumor Segmentation task using the 2018 dataset, but for some reason the loss decreases during training while the predictions are all classified as background, so all the metrics have a score of 0. The model I am using is the UNet3D from this repo: Unet3D.
Here is the code I am using for the training of the network:

def train(args, logging_path):
    """Train a UNet3D segmentation model on the BraTS2018 training split.

    The model is optimised with Adam and a voxel-wise cross-entropy loss. The
    learning rate follows a polynomial decay,
    ``base_lr * (1 - iter / max_iterations) ** 0.9``. Every 500 iterations the
    model is evaluated with ``test_all_case`` and every 1000 iterations a
    checkpoint is written to ``logging_path``.

    Args:
        args: Parsed options. This function reads ``base_lr``,
            ``num_channels``, ``num_classes``, ``batch_size``,
            ``max_iterations``, ``seed``, ``data_path`` and ``patch_size``.
        logging_path: Directory in which ``iter_<n>.pth`` checkpoints are
            saved. It is created if it does not exist.

    Returns:
        The string ``"Training Finished!"``.
    """
    base_lr = args.base_lr
    num_channels = args.num_channels
    num_classes = args.num_classes
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def create_model():
        # return Swin(in_channel=num_channels, num_classes=num_classes, window_size=(4,4,4)).cuda()
        return UNet3D(in_channels=num_channels, n_classes=num_classes).cuda()

    model = create_model()
    ModelParamInit(model)

    def worker_init_fn(worker_id):
        # The augmentations draw from np.random, so NumPy has to be seeded per
        # worker as well; otherwise workers can replay identical augmentations.
        random.seed(args.seed + worker_id)
        np.random.seed(args.seed + worker_id)

    labeled_train_set = BraTS2018(
        base_dir=args.data_path,
        split='train',
        num=None,
        img_type='all_modalities',
        transform=transforms.Compose(
            [RandomRotFlip(), RandomCrop(args.patch_size), ToTensor()]),
    )

    # shuffle=True: without it every epoch visits the cases in the same order.
    trainloader_labeled = DataLoader(labeled_train_set,
                                     batch_size=batch_size, shuffle=True,
                                     num_workers=4, pin_memory=True,
                                     worker_init_fn=worker_init_fn)

    model.train()

    adam_optimizer = optim.Adam(model.parameters(), lr=base_lr, weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()

    # torch.save does not create missing directories.
    os.makedirs(logging_path, exist_ok=True)

    logging.info("{} iterations per epoch".format(len(trainloader_labeled)))
    iter_num = 0
    max_epoch = max_iterations // len(trainloader_labeled) + 1
    iterator = tqdm(range(max_epoch), ncols=70)

    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader_labeled):
            volume_batch = sampled_batch['image'].cuda()
            label_batch = sampled_batch['label'].cuda()

            output = model(volume_batch)
            loss = ce_loss(output, label_batch.long())

            adam_optimizer.zero_grad()
            loss.backward()
            adam_optimizer.step()

            iter_num = iter_num + 1

            # Polynomial learning-rate decay, applied for the next iteration.
            lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
            for param_group in adam_optimizer.param_groups:
                param_group['lr'] = lr_

            logging.info('iteration %d : model loss : %f', iter_num, loss.item())

            if iter_num % 500 == 0:
                model.eval()
                # test_all_case is defined elsewhere; no_grad keeps autograd
                # graphs from being built in case it does not disable them.
                with torch.no_grad():
                    metric_list = test_all_case(
                        model, args.data_path, img_type="all_modalities",
                        num_classes=num_classes, patch_size=args.patch_size,
                        stride_xy=64, stride_z=64)

                logging.info(
                    'Core: Dice Coefficient: %f Hausdorff Distance: %f',
                    metric_list[0, 0], metric_list[0, 1])
                logging.info(
                    'Edema: Dice Coefficient: %f Hausdorff Distance: %f',
                    metric_list[1, 0], metric_list[1, 1])
                logging.info(
                    'Enhancing: Dice Coefficient: %f Hausdorff Distance: %f',
                    metric_list[2, 0], metric_list[2, 1])

                model.train()

            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(
                    logging_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    return "Training Finished!"

if __name__ == "__main__":
    # Script entry point: configure cuDNN, seed every RNG in use, set up
    # logging to both a file and stdout, then start training.
    # NOTE(review): `args` is not defined in this snippet; presumably it is
    # parsed at module level elsewhere - confirm.
    cudnn.benchmark = True
    cudnn.deterministic = False

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    logging_path = "./Logs/Iterations"
    # Creates ./Logs as well, which logging.basicConfig needs for log.txt.
    os.makedirs(logging_path, exist_ok=True)

    print(f"args.exp: {args.exp}, args.model: {args.model}, snapshot_path: {logging_path}")

    logging.basicConfig(filename="./Logs/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    train(args, logging_path)

`

Am I missing something very basic?
Thanks in advance,
Rodrigo

I’m not familiar with how your dataset is implemented, but one possibility could be that the label format does not match the output format of the model (which appears to be one-hot), and this could cause unwanted broadcasting which would cause the model to predict the “mean” of the dataset (possibly background in this case).

Could you check if the label shape is matching the output shape of the model for each iteration and that the labels are in the range [0, 1]?

This is basically my loader. It just loads a .h5 file with one channel for each of the 4 modalities and applies some data transformations.

class BraTS2018(Dataset):
    """BraTS2018 dataset that reads one ``<case>.h5`` file per sample.

    The case names come from ``<base_dir>/<split>.list``; only the text before
    the first comma of each line is used.

    Args:
        base_dir: Directory holding the ``.list`` files and the ``.h5`` cases.
        img_type: Stored on the instance; not used by this class itself.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        num: If not ``None``, keep only the first ``num`` cases.
        transform: Optional callable applied to each sample dict.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(self, base_dir=None, img_type=None, split='train', num=None, transform=None):
        self._base_dir = base_dir
        self.transform = transform
        self.sample_list = []
        self.split = split
        self.img_type = img_type

        print(self._base_dir)
        if self.split not in self._SPLITS:
            # Previously an unknown split surfaced later as an AttributeError
            # on the never-assigned image_list.
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, self.split))

        list_path = self._base_dir + '/' + self.split + '.list'
        with open(list_path, 'r') as f:
            lines = f.readlines()

        self.image_list = [item.replace('\n', '').split(",")[0] for item in lines]
        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        """Return the number of cases in the selected split."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load case ``idx`` and return a sample dict.

        The dict holds ``'image'`` (the four modalities stacked on a leading
        channel axis in the order flair, t1, t2, ce), ``'label'`` (uint8, with
        label value 4 remapped to 3) and ``'name'``. If a transform is set,
        its result is returned instead.
        """
        case = self.image_list[idx]

        # The context manager closes the file once the arrays are read, so
        # handles do not leak across the many calls made during training.
        with h5py.File(self._base_dir + "/{}.h5".format(case), "r") as h5f:
            flair = h5f['image'][:]
            t1 = h5f['t1'][:]
            t2 = h5f['t2'][:]
            ce = h5f['ce'][:]
            seg = h5f['label'][:]

        input_image = np.zeros((4, flair.shape[0], flair.shape[1], flair.shape[2]))
        for channel, modality in enumerate((flair, t1, t2, ce)):
            input_image[channel, :, :, :] = modality

        # Remap label value 4 to 3 so the label values are contiguous.
        seg[seg == 4] = 3

        sample = {'image': input_image, 'label': seg.astype(np.uint8), 'name': case}

        if self.transform:
            sample = self.transform(sample)
        return sample


class CenterCrop(object):
    """Crop the centre ``output_size`` region out of a sample.

    The label must be a 3-D array. The image may be 3-D, or carry leading
    axes (for example a channel axis) in front of its three spatial axes;
    leading axes are neither padded nor cropped.

    Args:
        output_size: Sequence of three ints giving the crop size.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Return ``{'image', 'label'}`` cropped around the spatial centre."""
        image, label = sample['image'], sample['label']
        out_w, out_h, out_d = self.output_size[0], self.output_size[1], self.output_size[2]

        # pad the sample if necessary
        if label.shape[0] <= out_w or label.shape[1] <= out_h or label.shape[2] <= out_d:
            pw = max((out_w - label.shape[0]) // 2 + 3, 0)
            ph = max((out_h - label.shape[1]) // 2 + 3, 0)
            pd = max((out_d - label.shape[2]) // 2 + 3, 0)
            spatial_pad = [(pw, pw), (ph, ph), (pd, pd)]
            # Leading (non-spatial) image axes get no padding.
            leading_pad = [(0, 0)] * (np.ndim(image) - 3)
            image = np.pad(image, leading_pad + spatial_pad,
                           mode='constant', constant_values=0)
            label = np.pad(label, spatial_pad,
                           mode='constant', constant_values=0)

        w, h, d = image.shape[-3:]

        w1 = int(round((w - out_w) / 2.))
        h1 = int(round((h - out_h) / 2.))
        d1 = int(round((d - out_d) / 2.))

        window = (slice(w1, w1 + out_w), slice(h1, h1 + out_h), slice(d1, d1 + out_d))
        label = label[window]
        # Ellipsis keeps any leading axes of the image intact.
        image = image[(Ellipsis,) + window]

        return {'image': image, 'label': label}


class RandomCrop(object):
    """Randomly crop a channel-first image and its label to ``output_size``.

    The image is indexed as ``(channels, w, h, d)`` and the label as
    ``(w, h, d)``. Any number of channels is supported.

    Args:
        output_size: Sequence of three ints giving the crop size.
        with_sdf: Stored on the instance; not used by this class itself.
    """

    def __init__(self, output_size, with_sdf=False):
        self.output_size = output_size
        self.with_sdf = with_sdf

    def __call__(self, sample):
        """Return ``{'image', 'label'}`` cropped at a random spatial offset."""
        image, label = sample['image'], sample['label']
        out_w, out_h, out_d = self.output_size[0], self.output_size[1], self.output_size[2]

        # pad the sample if necessary
        if label.shape[0] <= out_w or label.shape[1] <= out_h or label.shape[2] <= out_d:
            pw = max((out_w - label.shape[0]) // 2 + 3, 0)
            ph = max((out_h - label.shape[1]) // 2 + 3, 0)
            pd = max((out_d - label.shape[2]) // 2 + 3, 0)
            spatial_pad = [(pw, pw), (ph, ph), (pd, pd)]

            # Pad every channel in one call rather than assuming exactly four.
            # The earlier implementation produced a float64 array on this
            # path, so the dtype is kept for backward compatibility.
            image = np.pad(image, [(0, 0)] + spatial_pad,
                           mode='constant', constant_values=0).astype(np.float64, copy=False)
            label = np.pad(label, spatial_pad, mode='constant', constant_values=0)

        (_, w, h, d) = image.shape
        # After padding every spatial size is strictly larger than the
        # requested size, so the upper bounds passed to randint are positive.
        w1 = np.random.randint(0, w - out_w)
        h1 = np.random.randint(0, h - out_h)
        d1 = np.random.randint(0, d - out_d)

        label = label[w1:w1 + out_w, h1:h1 + out_h, d1:d1 + out_d]
        image = image[:, w1:w1 + out_w, h1:h1 + out_h, d1:d1 + out_d]

        return {'image': image, 'label': label}


class RandomRotFlip(object):
    """Apply the same random 90-degree rotation and flip to image and label.

    The image is indexed as ``(channels, w, h, d)`` and the label as
    ``(w, h, d)``. Both are rotated ``k`` times by 90 degrees in the plane of
    the first two spatial axes (``k`` drawn from 0..3) and then flipped along
    one of those two axes. Any number of channels is supported.
    """

    def __call__(self, sample):
        """Return ``{'image', 'label'}`` after the random rotation and flip."""
        image, label = sample['image'], sample['label']

        # Draw order (k first, then axis) is kept so seeded runs reproduce.
        k = np.random.randint(0, 4)
        axis = np.random.randint(0, 2)

        # Rotating axes (1, 2) of the channel-first image equals rotating
        # every channel separately in its first two axes.
        image = np.flip(np.rot90(image, k, axes=(1, 2)), axis=axis + 1)
        label = np.flip(np.rot90(label, k), axis=axis).copy()

        # Copy into a fresh C-contiguous float64 array: np.flip returns a
        # negatively strided view, and the earlier implementation always
        # returned a new float64 array.
        input_image = np.array(image, dtype=np.float64, order='C')

        return {'image': input_image, 'label': label}



class RandomNoise(object):
    """Add clipped Gaussian noise to the image of a sample.

    Noise is drawn as ``sigma * N(0, 1)``, clipped to ``[-2*sigma, 2*sigma]``
    and shifted by ``mu``. The label is passed through unchanged.

    Args:
        mu: Offset added to the clipped noise.
        sigma: Standard deviation of the noise before clipping.
    """

    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        """Return ``{'image', 'label'}`` with noise added to the image."""
        image, label = sample['image'], sample['label']
        # Draw noise with the full image shape. Using only the first three
        # dimensions fails to broadcast for channel-first 4-D images.
        noise = np.clip(self.sigma * np.random.randn(*image.shape),
                        -2 * self.sigma, 2 * self.sigma)
        noise = noise + self.mu
        image = image + noise
        return {'image': image, 'label': label}


class CreateOnehotLabel(object):
    """Attach a one-hot encoding of the label to a sample.

    Args:
        num_classes: Number of classes; class ``i`` is encoded in channel
            ``i`` of the one-hot array.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, sample):
        """Return the sample's image and label plus ``'onehot_label'``.

        ``'onehot_label'`` is a float32 array of shape
        ``(num_classes, *label.shape)`` for a 3-D label.
        """
        label = sample['label']
        # Compare the label against every class index at once: a column of
        # class ids broadcast against the label with a new leading axis.
        class_ids = np.arange(self.num_classes).reshape(-1, 1, 1, 1)
        encoded = (label[np.newaxis] == class_ids).astype(np.float32)
        return {'image': sample['image'], 'label': label, 'onehot_label': encoded}


class ToTensor(object):
    """Turn the NumPy arrays of a sample into torch tensors."""

    def __call__(self, sample):
        """Return ``{'image': float32 tensor, 'label': int64 tensor}``."""
        image_tensor = torch.from_numpy(sample['image']).float()
        label_tensor = torch.from_numpy(sample['label']).long()
        return dict(image=image_tensor, label=label_tensor)


class TwoStreamBatchSampler(Sampler):
    """Batch sampler that draws every batch from two index pools.

    Each batch is ``batch_size - secondary_batch_size`` primary indices
    followed by ``secondary_batch_size`` secondary indices. One pass over the
    primary indices defines an epoch; the secondary indices are reshuffled
    and reused for as long as the primary pass lasts.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        # Both pools must be able to fill at least one non-empty share.
        primary_count = len(self.primary_indices)
        assert 0 < self.primary_batch_size <= primary_count
        secondary_count = len(self.secondary_indices)
        assert 0 < self.secondary_batch_size <= secondary_count

    def __iter__(self):
        # Both streams are created here, at call time, before any batch is
        # consumed; only the batch concatenation itself is lazy.
        primary_chunks = grouper(iterate_once(self.primary_indices),
                                 self.primary_batch_size)
        secondary_chunks = grouper(iterate_eternally(self.secondary_indices),
                                   self.secondary_batch_size)
        paired = zip(primary_chunks, secondary_chunks)
        return (first + second for first, second in paired)

    def __len__(self):
        full_batches, _ = divmod(len(self.primary_indices), self.primary_batch_size)
        return full_batches


def iterate_once(iterable):
    """Return a single random permutation of ``iterable`` (via NumPy)."""
    shuffled = np.random.permutation(iterable)
    return shuffled


def iterate_eternally(indices):
    """Return an endless iterator over freshly shuffled copies of ``indices``.

    A new permutation is drawn lazily each time the previous one runs out.
    """
    endless_permutations = (
        np.random.permutation(indices) for _ in itertools.count()
    )
    return itertools.chain.from_iterable(endless_permutations)


def grouper(iterable, n):
    """Collect data into fixed-length tuples, dropping an incomplete tail.

    Example: ``grouper('ABCDEFG', 3)`` yields ``ABC`` then ``DEF``.
    """
    # Every zip slot pulls from the same iterator, so each output tuple
    # consumes n consecutive items.
    shared = iter(iterable)
    return zip(*[shared for _ in range(n)])

Regarding the output shape and the label shape: the prediction of the model has a shape of (2,4,128,128,128), i.e. (B, C, H, W, D), and the label has a shape of (2,128,128,128), i.e. (B, H, W, D).
Is this a problem when calculating the Cross Entropy loss?

I think those shapes could be OK, according to the docs, but do the labels have values in the range [0, 4) (that is, [0, C) with C = 4 classes), as the docs require:

Target: If containing class indices, shape $()$, $(N)$ or $(N, d_1, d_2, \ldots, d_K)$ with $K \geq 1$ in the case of K-dimensional loss where each value should be between $[0, C)$. If containing class probabilities, same shape as the input and each value should be between $[0, 1]$.

I would also try to verify that your model can learn to predict non-background classes on trivially small synthetic datasets (where, e.g., you can artificially balance the data classes).