AssertionError: targets should not be none when in training mode

I am trying to train the Mask R-CNN model on my own dataset using torchvision.models.detection.maskrcnn_resnet50_fpn. I have already prepared my annotations in a JSON file containing the labels, the bounding boxes in [x1, y1, x2, y2] format, and the corresponding binary masks. This is the dataset class I am using:

import torch
import json
from torch.utils.data import Dataset
from pycocotools.coco import COCO
from PIL import Image
import os
import numpy as np
from torchvision import transforms
import Config

class CustomDataset(Dataset):
    def __init__(self, images_dir, masks_dir, json_file, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        with open(json_file) as f:
            self.data = json.load(f)
        self.transform = transform
        self.image_ids = [img['id'] for img in self.data["images"]]
        self.masks = {img['id']: np.array(Image.open(os.path.join(masks_dir, img['file_name'])).convert("L"), dtype = np.uint8) for img in self.data["images"]}
      
    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        # Get image ID
        img_id = self.image_ids[idx]
        img = next(image for image in self.data["images"] if image["id"] == img_id)
        # Load image
        image = np.array(Image.open(os.path.join(self.images_dir, img['file_name'])).convert("L"), dtype=np.float32)
        mask = self.masks[img_id]
        
        # apply the transform if any
        if self.transform:
            aug = self.transform(image = image, mask = mask)
            image = aug['image']
            mask = aug['mask']
        
        annotations = [ann for ann in self.data["annotations"] if ann["image_id"] == img_id]
        # extract boxes, labels and masks from annotations
        boxes = [ann["bbox"] for ann in annotations]
        labels = [ann["category_id"] for ann in annotations]
        # convert the binary mask array to a torch tensor
        mask = torch.tensor(mask)
        # keep boxes in [x1, y1, x2, y2] format
        boxes = [[bbox[0], bbox[1], bbox[2], bbox[3]] for bbox in boxes]
        # convert labels to integers
        labels = [int(label) for label in labels]
        # convert image and mask to torch tensors
        image = transforms.ToTensor()(image).to(Config.DEVICE)
        # create target dictionary
        target = {"boxes": torch.tensor(boxes).to(Config.DEVICE), "labels": torch.tensor(labels).to(Config.DEVICE), "masks": mask.to(Config.DEVICE)}
        return image, target

and I am training the model using this code:

import torch
import torchvision
from torchvision import models
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import Config
from dataloader import CustomDataset
from imutils import paths
import numpy as np
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
import torch.optim.lr_scheduler as lr_scheduler

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).to(Config.DEVICE)

json_path = 'path_to_json_file'

transform = A.Compose([A.Resize(Config.Input_Height, Config.Input_Width), 
                       A.Normalize(mean=(0.0), std=(1.0))])


# Create a dataloader for the custom dataset
dataset = CustomDataset(images_dir = Config.Image_dataset_dir, masks_dir = Config.Mask_dataset_dir, 
                        json_file=json_path, transform = transform)

# Split the data into training, validation and testing sets
train_split = 0.8
val_split = 0.1
train_size = int(train_split * len(dataset))
val_size = int(val_split * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=Config.Batch_size, shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=Config.Batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=Config.Batch_size, shuffle=False)

# Define the loss function, optimizer and scheduler
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

# Train the model
for e in tqdm(range(Config.Num_epochs)):
    totalTrainloss, totalValLoss = 0,0
    model.train()
    # loop over the training set
    for (i, (x, y)) in enumerate(train_loader):
        output = model(x)
        totalTrainloss += loss_function(output, y)
        optimizer.zero_grad()
        totalTrainloss.backward()
        optimizer.step()

    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        model.eval()

        # loop over the validation set
        for (x, y) in val_loader:
            # make the predictions and calculate the validation loss
            output = model(x)
            totalValLoss += loss_function(output, y)
    
    print("EPOCH: {}/{}".format(e + 1, Config.Num_epochs))
    print("Train loss: {:.4f}, Validation loss: {:.4f}".format(totalTrainloss, totalValLoss))

but training the model returns the error mentioned in the title: AssertionError: targets should not be none when in training mode

I tried to loop over the train_loader and print out the image and target, and it returns the values from the dataloader as expected.

Any hint what’s going wrong here please?

During training, this detection model expects to receive the input data as well as the targets, as indicated in the error message.
Something like this would work:

output = model(images,targets)   # Returns losses and detections
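
In training mode the model then returns a dict of partial losses rather than predictions, so the usual pattern (a minimal sketch, assuming images and targets come from your DataLoader in the expected format) is to sum the dict values and backpropagate that:

model.train()
for images, targets in train_loader:
    # forward pass with inputs and targets; a dict of losses is returned
    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()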

Hi @ptrblck thank you for your response!

I just passed the target to the model in

for (i, (x, y)) in enumerate(train_loader):
    output = model(x, y)

but it still returned an error:

TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_10916\4011324141.py in <module>
     53     # loop over the training set
     54     for (i, (x, y)) in enumerate(train_loader):
---> 55         output = model(x, y)
     56         totalTrainloss += loss_function(output, y)
     57         optimizer.zero_grad()

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\Anaconda3\lib\site-packages\torchvision\models\detection\generalized_rcnn.py in forward(self, images, targets)
     63             else:
     64                 for target in targets:
---> 65                     boxes = target["boxes"]
     66                     if isinstance(boxes, torch.Tensor):
     67                         torch._assert(

TypeError: string indices must be integers

Given that I have a separate bounding box for each object I am trying to detect in my images (i.e. if an image has, let's say, 3 objects, 3 different records are made for the same image, each with the label and the coordinates of one bounding box), could that be the problem?

Based on the error message it seems your target object is a string while a dict is expected:

target = "this is a string"
target["boxes"]
# TypeError: string indices must be integers

Take a look at this tutorial to see which objects are expected.
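
The per-image targets in that tutorial include boxes as a FloatTensor[N, 4] in [x1, y1, x2, y2] format, labels as an Int64Tensor[N], and masks as a UInt8Tensor[N, H, W]. A minimal sketch with made-up values, just to illustrate the expected shapes and dtypes:

import torch

# hypothetical target for one image containing a single object
target = {
    "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]], dtype=torch.float32),  # [N, 4], xyxy
    "labels": torch.tensor([1], dtype=torch.int64),                            # [N]
    "masks": torch.zeros((1, 600, 800), dtype=torch.uint8),                    # [N, H, W], one binary mask per object
}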

Hi again @ptrblck … that's very strange indeed, because I had a look at the tutorial and double-checked the target returned from my dataloader, and everything seems to be correct! This was the output while looping over the train_loader:

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=Config.Batch_size, shuffle=True)

for x, y in train_loader:
    print(y["boxes"])

Output:

tensor([[[  2., 488., 750.,  79.],
         [138., 471., 371., 496.],
         [197., 549., 239., 156.]]])
tensor([[[514., 530., 456.,  16.]]])
tensor([[[232., 387., 266., 126.],
         [  3., 493., 503., 245.],
         [554., 442., 133., 103.]]])
tensor([[[  5., 530., 368.,  31.]]])
tensor([[[  3., 110., 625., 352.],
         [  2., 199., 390.,  95.],
         [  3., 365., 599.,  48.],
         [  3., 438., 593.,  57.],
         [  2., 323., 356.,  58.]]])
tensor([[[  1., 470., 557.,  76.]]])
tensor([[[  1., 436., 295.,  36.],
         [  0., 506., 596.,  99.]]])
tensor([[[  1., 109., 629., 358.],
         [  3., 201., 386.,  91.],
         [  3., 362., 594.,  51.],
         [  1., 372., 353.,  72.],
         [  3., 435., 588.,  65.]]])
tensor([[[160., 478., 349.,  39.],
         [162., 477., 560.,  76.],
         [241., 460., 240., 348.],
         [325., 609., 281., 198.],
         [564., 510., 196., 298.]]])

So the returned target is of a dict type. Any idea what could be wrong here, please? Thank you so much in advance

Sorry, I was wrong and the expected target type seems to be a list of dicts.
If you iterate the dict directly you would access its keys which raises the error:

# fails
targets = {"boxes": torch.randn(10, 10)}
for target in targets:
    print(target)
    # boxes
    target["boxes"]
    # TypeError: string indices must be integers
    
# works
targets = [{"boxes": torch.randn(10, 10)}]
for target in targets:
    print(target)
    # {'boxes': tensor([[ 1.3441, -1.2482, -1.0498, -1.1825,  0.0875,  0.5161, -0.7800, -0.7537,
    #           0.7275,  2.3111],
    target["boxes"]

Thank you so much @ptrblck for your clarification. I managed to overcome this issue, but another one popped up! After returning a list of dicts as you recommended, this error appears:

AssertionError: Expected target boxes to be a tensor of shape [N, 4], got torch.Size([1, 3, 4]).

Most probably the extra first dimension on boxes is the batch size added when creating the dataloader. I tried to return torch.squeeze(boxes, dim=0), but that apparently does not seem to work. How can this problem be solved, please?

torch.squeeze only squeezes dimensions of size 1.
For example,

a = torch.randn(1, 3, 4)
torch.squeeze(a, dim=0) # returns a tensor of shape (3, 4)

b = torch.randn(2, 3, 4)
torch.squeeze(b, dim=0) # returns a tensor of shape (2, 3, 4), unchanged

Consequently, you should consider the case where the batch size is not 1.

boxes = boxes.view(-1, 4)

will work
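
For example, a quick sketch showing that view collapses the leading batch dimension into the box dimension:

boxes = torch.randn(1, 3, 4)
boxes = boxes.view(-1, 4)
print(boxes.shape)
# torch.Size([3, 4])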

Thank you @thecho7 for your answer! When I loop over the created dataset object, it returns the correct shape expected by the network, [N, 4], as follows:

dataset = CustomDataset(images_dir = Config.Image_dataset_dir, masks_dir = Config.Mask_dataset_dir,
                        json_file=json_path, transform = transform)

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=Config.Batch_size, shuffle=True)

for x, y in train_dataset:
    print(y[0]["boxes"].shape)

output is:

torch.Size([3, 4])
torch.Size([5, 4])
torch.Size([1, 4])
torch.Size([5, 4])
torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([3, 4])
torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([3, 4]) 
.......

But when I try to print out the shapes of the boxes from the dataloader, as follows:

for x, y in train_loader:
    print(y[0]["boxes"].shape)

it returns:

torch.Size([1, 1, 4])
torch.Size([1, 1, 4])
torch.Size([1, 1, 4])
torch.Size([1, 5, 4])
torch.Size([1, 3, 4])
.............

Given that I am using batch_size = 1, that extra dimension of size 1 is added by the train_loader! I tried to return boxes after boxes = boxes.view(-1, 4), but it is still returning the same shape. How can such an issue be fixed?

I assume the list should contain batch_size tensors each containing all bounding boxes, so you might need to revisit your current workflow.
As @thecho7 described, squeeze should work:

# works
targets = [{"boxes": torch.randn(1, 10, 10)}]
for target in targets:
    print(target)
    # {'boxes': tensor([[ 1.3441, -1.2482, -1.0498, -1.1825,  0.0875,  0.5161, -0.7800, -0.7537,
    #           0.7275,  2.3111],
    print(target["boxes"].shape)
    # torch.Size([1, 10, 10])
    boxes = target["boxes"].squeeze(0)
    print(boxes.shape)
    # torch.Size([10, 10])

Hi again @ptrblck and @thecho7, I managed to solve this problem by implementing a collate_fn that stacks the images and collects the targets of each batch. It looks as follows:

def collate_fn(batch):
    images, targets = [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
    return torch.stack(images), targets
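
For what it's worth, I believe the torchvision detection reference scripts use an equivalent collate_fn that simply zips the batch, returning a tuple of image tensors and a tuple of target dicts, which also works when the images in a batch have different sizes:

def collate_fn(batch):
    # batch is a list of (image, target) pairs
    return tuple(zip(*batch))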

After some reading, it seems that the model expects its inputs as a tuple (images, targets), where images is a list of 3D tensors for the images in the batch and targets is a list of dictionaries, each containing the information about the objects in the corresponding image. After ensuring that the images, masks, bounding boxes, labels, and image_id returned from the dataloader have the right shapes and are aligned with each other, I tried to run the training again using the code from this tutorial, but now I am getting a whole new error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_14092\2448976667.py in <module>
    109 for epoch in range(num_epochs):
    110     # train for one epoch, printing every 10 iterations
--> 111     train_one_epoch(model = model, optimizer = optimizer, 
    112                     data_loader = data_loader, device = Config.DEVICE, epoch = epoch, print_freq=10)
    113     # update the learning rate

~\Desktop\Master Thesis Project\Mask RCNN\engine.py in train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler)
     29         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
     30         with torch.cuda.amp.autocast(enabled=scaler is not None):
---> 31             loss_dict = model(images, targets)
     32             losses = sum(loss for loss in loss_dict.values())
     33 

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\Anaconda3\lib\site-packages\torchvision\models\detection\generalized_rcnn.py in forward(self, images, targets)
     81             original_image_sizes.append((val[0], val[1]))
     82 
---> 83         images, targets = self.transform(images, targets)
     84 
     85         # Check for degenerate boxes

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\Anaconda3\lib\site-packages\torchvision\models\detection\transform.py in forward(self, images, targets)
    134 
    135         image_sizes = [img.shape[-2:] for img in images]
--> 136         images = self.batch_images(images, size_divisible=self.size_divisible)
    137         image_sizes_list: List[Tuple[int, int]] = []
    138         for image_size in image_sizes:

~\Anaconda3\lib\site-packages\torchvision\models\detection\transform.py in batch_images(self, images, size_divisible)
    231             return self._onnx_batch_images(images, size_divisible)
    232 
--> 233         max_size = self.max_by_axis([list(img.shape) for img in images])
    234         stride = float(size_divisible)
    235         max_size = list(max_size)

~\Anaconda3\lib\site-packages\torchvision\models\detection\transform.py in max_by_axis(self, the_list)
    219 
    220     def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
--> 221         maxes = the_list[0]
    222         for sublist in the_list[1:]:
    223             for index, item in enumerate(sublist):

IndexError: list index out of range

This error suggests that the list of images is empty, which should be impossible given all the checks I have been doing for a whole day now. Any hint what could be wrong here, please? This is driving me insane.

Could you add a few debug print statements to the GeneralizedRCNNTransform.forward in ~\Anaconda3\lib\site-packages\torchvision\models\detection\transform.py at line 107+ and check what images is before being passed to self.batch_images?
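
Something as simple as this, added right before the batch_images call, should already show what images actually is at that point (the exact line numbers differ between torchvision versions):

print(type(images))
for img in images:
    print(type(img), getattr(img, "shape", None))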

I printed the images passed to the train_one_epoch function and they were not passed in the right format; the images were printed as <generator object train_one_epoch.<locals>.<genexpr> at 0x0000019FD323EE40>, so I traced that back to the references/detection/engine.py file, and the problem was here:

for images, targets in metric_logger.log_every(data_loader, print_freq, header):
    images = (image.to(device) for image in images)

so I changed it to:

for images, targets in metric_logger.log_every(data_loader, print_freq, header):
    images = list(image.to(device) for image in images)

to return a list of tensors as input for the model, and that worked. The training is running now, but only for one iteration, after which it outputs NaN losses, as follows:

Epoch: [0]  [   0/2759]  eta: 13:32:17  lr: 0.000010  loss: -136.8811 (-136.8811)  loss_classifier: 0.9397 (0.9397)  loss_box_reg: 0.0017 (0.0017)  loss_mask: -137.9142 (-137.9142)  loss_objectness: 0.0859 (0.0859)  loss_rpn_box_reg: 0.0057 (0.0057)  time: 17.6648  data: 10.2502
Loss is nan, stopping training
{'loss_classifier': tensor(nan, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(nan, grad_fn=<DivBackward0>), 'loss_mask': tensor(nan, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_objectness': tensor(nan, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(nan, grad_fn=<DivBackward0>)}
An exception has occurred, use %tb to see the full traceback.

SystemExit: 1

The error is raised from this part of the train_one_epoch function:

# reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())

loss_value = losses_reduced.item()

if not math.isfinite(loss_value):
    print(f"Loss is {loss_value}, stopping training")
    print(loss_dict_reduced)
    sys.exit(1)

And I really have no idea what is wrong anymore. Any help with that, please?

Based on your output it seems the very first iteration already creates the NaN outputs?
If so, could you add more debug print statements into the forward method(s) of your model(s) and check where the invalid values are created? Forward hooks should also work to get you an idea which layer created these.
Something like this should already point out the module causing the issue:

def nan_hook(name):
    def hook(m, input, output):
        if not torch.isfinite(output).all():
            print("Invalid output in {}".format(name))
    return hook

for name, module in model.named_modules():
    module.register_forward_hook(nan_hook(name))

Yes, it's true that the NaN values appear in the very first epoch. The problem is, I didn't implement the model myself; instead, I am using the pretrained torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) model and fine-tuning it to detect objects in my own dataset, so I am not really sure where to add this hook function. For training, I am using the code from here. Can you please tell me what to do in this case to debug the problem?

Add this code before calling the forward pass of your model e.g. as seen here:

def nan_hook(name):
    def hook(m, input, output):
        if torch.is_tensor(output):
            if not torch.isfinite(output).all():
                print("Invalid output in {}".format(name))
        elif isinstance(output, tuple):
            for o in output:
                if torch.is_tensor(o):
                    if not torch.isfinite(o).all():
                        print("Invalid output in {}".format(name))
    return hook


model = torchvision.models.detection.maskrcnn_resnet50_fpn()
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

for name, module in model.named_modules():
    module.register_forward_hook(nan_hook(name))

predictions = model(x) # no output

# create invalid inputs
x[0].fill_(torch.log(torch.tensor(-1.)))
predictions = model(x)
# Invalid output in backbone.body.conv1
# Invalid output in backbone.body.bn1
# Invalid output in backbone.body.relu
# Invalid output in backbone.body.maxpool
# Invalid output in backbone.body.layer1.0.conv1
# Invalid output in backbone.body.layer1.0.bn1
# ...

I have adapted the code a bit to be able to check tuples as well and you can improve it further if needed.


Hi @ptrblck, thank you so much for your help! After debugging the training code, the problem was mainly that my mask values were not in the range [0, 1] but in [0, 255], which caused the gradients to explode and simply return NaN. Now the training is working fine, but after some iterations over the batches in the first epoch, this error is raised:

AssertionError: All bounding boxes should have positive height and width. Found invalid box [490.3699035644531, 528.8463134765625, 490.3699035644531, 643.4569702148438] for target at index 0.

I looked this problem up, and some colleagues suggested that the data format of the bbox is wrong and that it should be passed as [Xmin, Ymin, Xmax, Ymax], but I am already organizing the coordinates in this format. To make doubly sure that everything is passed correctly, I am returning

boxes = [[min(bbox[0], bbox[2]), min(bbox[1], bbox[3]), max(bbox[0], bbox[2]), max(bbox[1], bbox[3])] for bbox in boxes]

from the dataloader, but it seems there are instances in the data where Xmin equals Xmax, which is causing the problem. How can this be tackled? Thank you so much in advance!

Your explanation is correct and the error is most likely raised because your bounding box is empty, i.e. the xmin/xmax values are equal.
Remove these empty bounding boxes from the training set, or expand the box if you think it is tagging a valid target.
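
A minimal sketch of such a filter inside __getitem__, assuming boxes and labels are still plain Python lists in [x1, y1, x2, y2] format at that point:

# drop degenerate boxes (zero width or height) together with their labels
keep = [i for i, (x1, y1, x2, y2) in enumerate(boxes) if x2 > x1 and y2 > y1]
boxes = [boxes[i] for i in keep]
labels = [labels[i] for i in keep]
# if per-instance masks are used, they would need the same filtering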