Multi-class implementation of Mask R-CNN

Dear all,

I’m following the tutorial:
TORCHVISION OBJECT DETECTION FINETUNING TUTORIAL
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

I’m able to train and test it on my own dataset of eyes. The problem I have is that the tutorial only covers 2 classes: background and people. The masks are encoded as:
‘zeros’ = background
‘ones’ = instance number 1
‘twos’ = instance number 2

The last 2 instances in this case are the same class: People.

What do I have to change to make it work for multiple classes? I.e., I want to classify pupils, irises, and background (3 classes).

Do I have to make separate masks for pupils and irises, i.e.:
Folder1
‘zeros’ = background
‘ones’ = iris

Folder2
‘zeros’ = background
‘ones’ = pupil

or can I identify them in 1 overall mask:
‘zeros’ = background
‘ones’ = iris
‘twos’ = pupil

The code also mentions num_classes = 2.
Changing this to 3 doesn’t really make sense to me because I don’t know what the setup should be.

Kind regards
Rasmus

No one?

I tried some other frameworks, which work, but I can’t get it to work with this one.

replacing

labels = torch.ones((num_objs,), dtype=torch.int64)

with the following

labels = torch.as_tensor(obj_ids, dtype=torch.int64)

will do.
In fact, the replacement even works for two-class instance segmentation.
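As a tiny self-contained illustration of what that replacement does (a made-up mask where the pixel values are the class IDs, 1 = iris, 2 = pupil, as in the question above):

import numpy as np
import torch

# toy one-channel mask: 0 = background, 1 = iris, 2 = pupil
mask = np.array([[0, 1, 1],
                 [0, 2, 2]])

obj_ids = np.unique(mask)[1:]                          # drop the background value -> array([1, 2])

# old tutorial line: labels = torch.ones((num_objs,), dtype=torch.int64)  -> tensor([1, 1])
labels = torch.as_tensor(obj_ids, dtype=torch.int64)   # -> tensor([1, 2]), i.e. iris and pupil
print(labels)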

Hi, I tried what you advised, however this error is triggered:

RuntimeError Traceback (most recent call last)
in ()
3 for epoch in range(num_epochs):
4 # train for one epoch, printing every 10 iterations
----> 5 train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
6 # update the learning rate
7 lr_scheduler.step()

5 frames
/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/roi_heads.py in fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
40 # the corresponding ground truth labels, to be used with
41 # advanced indexing
---> 42 sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
43 labels_pos = labels[sampled_pos_inds_subset]
44 N, num_classes = class_logits.shape

RuntimeError: copy_if failed to synchronize: cudaErrorAssert: device-side assert triggered

It’s working fine for me.
Yet, the copy_if failed to synchronize error may be due to an incorrect number of classes used to build the model, and might not be related to the correction I am suggesting above. One cause of this error, which I have encountered myself, is neglecting the background in the number of classes when building the model. Hence, if you have 10 object categories that you want to detect (background not counted), you’ll have to build your model to detect 11 classes; that is, the number of classes should be 11.
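Following the tutorial’s model-building helper, that looks like the sketch below (the 10 categories are just the example above; newer torchvision versions take weights=... instead of pretrained=True):

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

num_object_categories = 10                # your real object categories
num_classes = num_object_categories + 1   # +1 for the background class

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# replace the box predictor with one that outputs num_classes scores
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# replace the mask predictor with one that outputs num_classes mask channels
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)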

Hi

This is my model: 20 classes + background, 21 in total. I still get the error.

(box_predictor): FastRCNNPredictor(
(cls_score): Linear(in_features=1024, out_features=21, bias=True)
(bbox_pred): Linear(in_features=1024, out_features=84, bias=True)
)
(mask_roi_pool): MultiScaleRoIAlign()
(mask_head): MaskRCNNHeads(
(mask_fcn1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU(inplace=True)
(mask_fcn2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2): ReLU(inplace=True)
(mask_fcn3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3): ReLU(inplace=True)
(mask_fcn4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4): ReLU(inplace=True)
)
(mask_predictor): MaskRCNNPredictor(
(conv5_mask): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
(relu): ReLU(inplace=True)
(mask_fcn_logits): Conv2d(256, 21, kernel_size=(1, 1), stride=(1, 1))

Why don’t you just follow the tutorial section modifying-the-model-to-add-a-different-backbone?

And what should my dataset look like? I put 2 examples in my initial question above.

I had followed the tutorial, but I don’t need to modify the backbone at this stage. If I cannot get a decent AP with Faster R-CNN as a baseline, there is no point in changing the backbone. I am using the Pascal VOC dataset and was expecting results similar to the benchmark, but it seems that something is missing.

you’ll need to put

‘zeros’ = background
‘ones’ = iris
‘twos’ = pupil

into one numpy mask and you are good to go. You can then use and modify the tutorial’s PennFudanDataset class accordingly, in addition to using labels = torch.as_tensor(obj_ids, dtype=torch.int64) as suggested above, and set num_classes=3. This will work out of the box.
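If you currently have separate per-class masks (the two-folder layout from the original question), a rough sketch of merging them into one such mask could look like this (the file names are hypothetical):

import numpy as np
from PIL import Image

# hypothetical per-class binary masks loaded from the two folders
iris_mask = np.array(Image.open("Folder1/eye_0001.png").convert("L")) > 0
pupil_mask = np.array(Image.open("Folder2/eye_0001.png").convert("L")) > 0

combined = np.zeros(iris_mask.shape, dtype=np.uint8)
combined[iris_mask] = 1    # 'ones' = iris
combined[pupil_mask] = 2   # 'twos' = pupil (overwrites iris pixels where the two overlap)

Image.fromarray(combined).save("combined_mask.png")  # single-channel mask with values 0/1/2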

If you don’t want to change the backbone, and opt for the resnet50 model described in the tutorial, all should go well if you use labels = torch.as_tensor(obj_ids, dtype=torch.int64), unless there is something incorrect in your dataset class. Hence, I suggest you revise your dataset class; the best way to do this is to invoke it separately (a simple program with a main function) and check that everything is correct, preferably in debug mode (via Spyder or PyCharm) to trace everything step by step. You might as well go over the rest of your code to make sure everything is correct. I have tried it with 60 different classes and it is working.

So I made my dataset the way you described, set num_classes to 3, and changed the labels line.

I now get an error in the following part of the code:

# split the color-encoded mask into a set of binary masks
masks = mask == obj_ids[:, None, None]
print(obj_ids[:, None, None])

# get bounding box coordinates for each mask
num_objs = len(obj_ids)

boxes = []
for i in range(num_objs):
    pos = np.where(masks[i])
    xmin = np.min(pos[1])
    xmax = np.max(pos[1])
    ymin = np.min(pos[0])
    ymax = np.max(pos[0])
    boxes.append([xmin, ymin, xmax, ymax])

The error that I get:

    data.reraise()
  File "anaconda3/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "multi_training.py", line 55, in __getitem__
    pos = np.where(masks[i])
TypeError: 'bool' object is not subscriptable

The obj_ids print statement gives:

[[[1]]

 [[2]]]

masks becomes False instead of a 2D array of False and True values.

What can I do? Or did I make a mistake with the masks?

This is one of the masks (note: iris = (1,1,1), pupil = (2,2,2)):
[attached mask image: nhai_right]

The changes work with the old, 2-class dataset.

I think this is a problem related to numpy, not sure which version you’re using. I am using

np.__version__
'1.17.4'

and it does not have this problem. To confirm, here’s a snapshot of one of the masks:

[False, False, False, ..., False, False, False],
 [False, False, False, ..., False, False, False],
 [False, False, False, ..., False, False, False]]])

type(masks[0][0][0])
<class 'numpy.bool_'>

The first 0 is the mask index, and the other two zeros are the x/y location in the mask.

If you are using an earlier numpy version, just upgrade it and test again.

NB: Not sure why in the ONE MASK example above your pixel/mask values show a triplet (iris = (1,1,1), pupil = (2,2,2))!? Each mask should only have one channel. That being said, imagine the original mask that contains the labels (call it the color-encoded mask), before converting it into binary masks, as a gray-level image with one channel filled with 0s, 1s, and 2s. The two binary masks are generated from the color-encoded mask by removing the background (the 0s) using obj_ids = np.unique(mask)[1:], followed by masks = mask == obj_ids[:, None, None].
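As a quick, made-up illustration of that decomposition (a tiny one-channel mask with values 0, 1, and 2):

import numpy as np

# toy color-encoded mask: 0 = background, 1 = iris, 2 = pupil
mask = np.array([[0, 0, 1, 1],
                 [0, 2, 2, 1],
                 [0, 2, 2, 0]])

obj_ids = np.unique(mask)[1:]            # background removed -> array([1, 2])
masks = mask == obj_ids[:, None, None]   # broadcast comparison against each object id

print(masks.shape)            # (2, 3, 4): one binary mask per object id
print(type(masks[0][0][0]))   # <class 'numpy.bool_'>, matching the snapshot above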

Thank you very much! The mistake was indeed that my masks were still RGB images.
Converting them with the following solved it:

mask = Image.open(mask_path).convert('L')

I can also now save my masks with:

img2 = Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())
img2.save('1.png')

img1 = Image.fromarray(prediction[0]['masks'][1, 0].mul(255).byte().cpu().numpy())
img1.save('2.png')

Did you find any way to draw the results like in the tutorial (https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html#defining-the-dataset),
with the bounding boxes and the outlines?

Glad it helped! This one worked for me
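A minimal sketch using torchvision’s drawing utilities (draw_bounding_boxes / draw_segmentation_masks), assuming torchvision >= 0.10 and that model, img, and device are defined as in the tutorial; the 0.5 thresholds are arbitrary:

import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision.transforms.functional import to_pil_image

model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])[0]

keep = prediction['scores'] > 0.5                 # arbitrary score threshold
boxes = prediction['boxes'][keep]
masks = prediction['masks'][keep, 0] > 0.5        # soft masks -> boolean masks

img_uint8 = (img * 255).to(torch.uint8).cpu()     # the drawing utils expect a uint8 image
out = draw_bounding_boxes(img_uint8, boxes.cpu(), colors="red", width=2)
out = draw_segmentation_masks(out, masks.cpu(), alpha=0.5)
to_pil_image(out).save("prediction.png")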

Do you guys have a sample of your working code? I am dealing with the same problem.

This has been up for some time; please try something similar to the __getitem__ below:

def __getitem__(self, index):    
        img_file = self.coco.loadImgs(self.imgIds[index] )[0]                               
        image_A = Image.open(self.path2images + img_file['file_name'])        
        annIds = self.coco.getAnnIds(imgIds=img_file['id'], catIds=self.catIds, 
                                     iscrowd=None) # suppose all instances are not crowd        
        
        anns = self.coco.loadAnns(annIds)   
        num_objs=len(anns)
                
        boxes=[]; labels=[]; area=[]               
        masks = np.zeros((num_objs, img_file['height'], img_file['width'] ) ) # just getting the shape of the mask
        for i in range(num_objs):
            labels.append(anns[i]['category_id'])            
            masks[i,:,:] = self.coco.annToMask(anns[i])            
            # boxes.append(anns[i]['bbox']) # seems there is a problem in the boxes 
            # area.append(anns[i]['area']) # and areas
              
        '''  I am calculating the bboxes and areas from the masks,
             as they might be incorrect; I've had a NaN in Mask R-CNN's loss,
             and after checking, it seems that the areas do not conform
             with the bounding boxes.
             But this is a problem with the dataset I'm using.
             Feel free to skip this if you're sure about your dataset.
        '''

        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])        
      
        boxes = torch.as_tensor(boxes, dtype=torch.float32)      
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        
        target = {}                
        target["boxes"]=  boxes
        target["labels"] = torch.as_tensor(labels, dtype=torch.int64) 
        target["masks"] = torch.as_tensor(masks, dtype=torch.uint8)
        target["image_id"] = torch.tensor([index]) 
        target["area"] =  area
        target["iscrowd"] = torch.zeros((num_objs,), dtype=torch.int64) # suppose all instances are not crowd             

        if self.transforms is not None:
            image_A = self.transforms(image_A)                
        
        return image_A, target
   

I am sorry, I think I am just an idiot. If I follow the tutorial from TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 1.10.0+cu102 documentation, where do I have to make changes to add more classes to the Mask R-CNN model? I have a dataset containing PNG masks and am trying to segment two classes: 1. legs, 2. A4 paper. I have been spinning in circles for a few days trying to get a more straightforward answer and I can’t really find anything…

P.S. The Colab version from the top of the page is not working.

You can manually check whether your dataset class (let it be denoted TheeDatasetClass) is working correctly before using Mask R-CNN; this is what I usually do. Just create an instance of the dataset and verify that you are getting the correct image, boxes, etc. Something like:

x = TheeDatasetClass(root, transforms=None)  # constructor arguments depend on your class
zz = x[1]

So, if you are using Spyder or PyCharm, you can check whether zz has what you wanted.

Tip: always use debugging to make sure your code is working as you want it to; this is one reason PyTorch is very handy, among other things. For example, if your dataset is causing a problem for some reason outside your control, just run a for-loop over the entire dataset and check which indices are causing the problem.
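A hypothetical sketch of such a loop (the dataset name and constructor arguments are placeholders):

# sanity-check every sample and report the indices that fail
dataset = TheeDatasetClass(root, transforms=None)   # placeholder constructor arguments

bad_indices = []
for idx in range(len(dataset)):
    try:
        img, target = dataset[idx]
        # a quick consistency check between boxes and labels
        assert target["boxes"].shape[0] == target["labels"].shape[0]
    except Exception as e:
        print(f"index {idx} failed: {e}")
        bad_indices.append(idx)

print("problematic indices:", bad_indices)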

import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image

class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "Images"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "Masks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "Images", self.imgs[idx])
        mask_path = os.path.join(self.root, "Masks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance
        # with 0 being background
        mask = Image.open(mask_path)

        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        #print(obj_ids)

        # split the color-encoded mask into a set
        # of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        # labels = torch.ones((num_objs,), dtype=torch.int64)
        labels = torch.as_tensor(obj_ids, dtype=torch.int64)  # changed: use the mask values as class ids
        #print('Labels Before', labels)
        labels[labels == 3] = 2  # changed: remap mask value 3 to class 2
        #print('Labels After', labels)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

So, as far as I understood, to train Mask R-CNN with multiple classes there are two things to change in the code:

  1. In the dataset class, replace the existing labels line with → labels = torch.as_tensor(obj_ids, dtype=torch.int64)
  2. Set num_classes = number of classes + 1 (background).
     E.g., if we have two classes, Pedestrian and Car, then num_classes = 3.

Is this correct?