How to get mask scores for image segmentation in Detectron2?

Hi, I’m trying to use Detectron2 to extract masks for image segmentation using Mask-RCNN. I used the command:

outputs = predictor(im) where predictor is a DefaultPredictor

However, the output has a field called pred_masks which contains only True or False values, while I want a value between 0 and 1 for each pixel (from what I understand from the Mask R-CNN paper, the mask head is supposed to output per-pixel values in [0, 1]). Please help me, thank you very much.

outputs should be a dict containing panoptic_seg, sem_seg or instances as seen here. Could you post a minimal code snippet to reproduce the issue?

The thing is that outputs comes from a DefaultPredictor, which only provides pred_masks (a binary mask) as a field of instances, but I want a score associated with each pixel.
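A quick check of the dtype (assuming outputs comes from predictor(im) as above) shows the masks are already binarized:

print(outputs['instances'].pred_masks.dtype)  # torch.bool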

Here is the example:

for img_name in img_list:
    im = cv2.imread(os.path.join(TEST_DIR, img_name))
    # if img_name == '438.jpg':
    #     pdb.set_trace()
    # pdb.set_trace()
    outputs = predictor(im)
    mask = outputs['instances'].pred_masks.to('cpu').numpy()
    mask = mask.astype(np.uint8)
    pdb.set_trace()
    mask[mask > 0] = 255
    (N, H, W) = mask.shape
    if N == 1:
        mask = mask.squeeze()
    else:
        # Many objects
        mask1 = np.zeros((H, W), dtype=np.uint8)
        for i in range(N):
            mask1[mask[i, :] > 0] = 255
        mask = mask1
    cv2.imwrite(os.path.join(OUTPUT_DIR, img_name), mask)

    num += 1
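As an aside, the per-instance merge in the else branch can also be written without the explicit loop; a small equivalent sketch, assuming mask is the (N, H, W) uint8 array built above:

# Collapse all N instance masks into a single 2-D mask:
# a pixel becomes 255 if any instance covers it.
combined = mask.any(axis=0).astype(np.uint8) * 255
cv2.imwrite(os.path.join(OUTPUT_DIR, img_name), combined)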

Your code snippet is unfortunately not executable. Could you use random inputs and define all classes, so that we could run the code and debug it, please?

Here is the code (I ran it on Google Colab):

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import cv2
import random
import pdb
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
im = cv2.imread("./input.jpg")
outputs = predictor(im)
mask = outputs['instances'].pred_masks.to('cpu').numpy()
pdb.set_trace()
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imwrite('./out.png', v.get_image()[:, :, ::-1])

The image can be downloaded with:

wget http://images.cocodataset.org/val2017/000000439715.jpg -O input.jpg

The problem is that the mask variable contains only 0 or 1, but I want a score between 0 and 1 for each pixel (as described in the Mask R-CNN paper). Thank you very much.

So after a while, I was able to figure out a makeshift solution (by looking at a few places: github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/postprocessing.py, and detectron2/layers/mask_ops.py at commit 98739e17ca01a46fe1da2db9ec09ed03564088bd in facebookresearch/detectron2 on GitHub). I copied the code from those links and modified only the part where the thresholding is involved (which is the last part). Here is the full code:

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# Import some common libraries

import numpy as np
import cv2
import random
import pdb
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model
import detectron2.data.transforms as T
from detectron2.structures import ImageList
from detectron2.modeling.backbone.build import build_backbone
from torch.nn import functional as F
from detectron2.structures import Instances
from detectron2.utils.memory import retry_if_cuda_oom
from PIL import Image

BYTES_PER_FLOAT = 4
# TODO: This memory limit may be too much or too little. It would be better to
# determine it based on available resources.
GPU_MEM_LIMIT = 1024 ** 3  # 1 GB memory limit

def preprocess_image(batched_inputs, cfg):
    """
    Normalize, pad and batch the input images.
    """
    pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)
    pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1)
    backbone = build_backbone(cfg)

    images = [x["image"].to(pixel_mean.device) for x in batched_inputs]
    images = [(x - pixel_mean) / pixel_std for x in images]
    images = ImageList.from_tensors(images, backbone.size_divisibility)
    return images

def custom_postprocess(instances, batched_inputs, image_sizes):
    """
    Rescale the output instances to the target size.
    """
    # note: private function; subject to changes
    processed_results = []
    for results_per_image, input_per_image, image_size in zip(
        instances, batched_inputs, image_sizes
    ):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        r = detector_postprocess(results_per_image, height, width)
        processed_results.append({"instances": r})
    return processed_results

def detector_postprocess(results, output_height, output_width, mask_threshold=0.5):
    """
    Resize the output instances.
    The input images are often resized when entering an object detector.
    As a result, we often need the outputs of the detector in a different
    resolution from its inputs.
    This function will resize the raw outputs of an R-CNN detector
    to produce outputs according to the desired output resolution.
    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the detector sees.
            This object might be modified in-place.
        output_height, output_width: the desired output resolution.
    Returns:
        Instances: the resized output from the model, based on the output resolution
    """
    scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0])
    results = Instances((output_height, output_width), **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes

    output_boxes.scale(scale_x, scale_y)
    output_boxes.clip(results.image_size)

    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        results.pred_masks = retry_if_cuda_oom(paste_masks_in_image)(
            results.pred_masks[:, 0, :, :],  # N, 1, M, M
            results.pred_boxes,
            results.image_size,
            threshold=mask_threshold,
        )

    if results.has("pred_keypoints"):
        results.pred_keypoints[:, :, 0] *= scale_x
        results.pred_keypoints[:, :, 1] *= scale_y

    return results

def custom_do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
    """
    Args:
        masks: N, 1, H, W
        boxes: N, 4
        img_h, img_w (int):
        skip_empty (bool): only paste masks within the region that
            tightly bound all boxes, and returns the results this region only.
            An important optimization for CPU.
    Returns:
        if skip_empty == False, a mask of shape (N, img_h, img_w)
        if skip_empty == True, a mask of shape (N, h', w'), and the slice
            object for the corresponding region.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale dataset.
    device = masks.device
    if skip_empty:
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    img_masks = F.grid_sample(masks.to(dtype=torch.float32), grid, align_corners=False)

    if skip_empty:
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()

def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.
    Note:
        This is a complicated but more accurate implementation. In actual deployment, it is
        often enough to use a faster but less accurate implementation.
        See :func:`paste_mask_in_image_old` in this file for an alternative implementation.
    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask width and mask
            height of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.
    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
            number of detected object instances and Himage, Wimage are the image width
            and height. img_masks[i] is a binary mask for object instance i.
    """

    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        return masks.new_empty((0,) + image_shape, dtype=torch.uint8)
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # The actual implementation split the input into chunks,
    # and paste them chunk by chunk.
    if device.type == "cpu":
        # CPU is most efficient when they are pasted one by one with skip_empty=True
        # so that it performs minimal number of operations.
        num_chunks = N
    else:
        # GPU benefits from parallelism for larger chunks, but may have memory issue
        # int(img_h) because shape may be tensors in tracing
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    # img_masks = torch.zeros(
    #     N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    # )
    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.uint8
    )
    for inds in chunks:
        masks_chunk, spatial_inds = custom_do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        # if threshold >= 0:
        #     masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        # else:
        # for visualization and debugging
        masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks
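The only functional change compared to the upstream mask_ops code is the last step: the soft masks_chunk values are scaled into uint8 instead of being thresholded at 0.5. If you would rather keep float probabilities in [0, 1], the tail of paste_masks_in_image above could instead look like this (just a sketch of the variation, I have not run this version):

    img_masks = torch.zeros(N, img_h, img_w, device=device, dtype=torch.float32)
    for inds in chunks:
        masks_chunk, spatial_inds = custom_do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )
        # Keep the pasted masks as raw probabilities instead of scaling to 0-255
        img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks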

cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles… url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

predictor = DefaultPredictor(cfg)

model = build_model(cfg)  # returns a torch.nn.Module
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
model.eval()

im = cv2.imread("./input.jpg")
predictor = DefaultPredictor(cfg)
outputs = predictor(im)
mask = outputs['instances'].pred_masks.to('cpu').numpy()
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imwrite('./out.png', v.get_image()[:, :, ::-1])
plt.close()

original_image = im
original_image = original_image[:, :, ::-1]
transform_gen = T.ResizeShortestEdge(
    [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
height, width = original_image.shape[:2]
image = transform_gen.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
outputs = model.inference([inputs], do_postprocess=False)
inputs_1 = preprocess_image([inputs], cfg)
image_sizes = inputs_1.image_sizes
processed_outputs = custom_postprocess(outputs, [inputs], image_sizes)

mask = outputs['instances'].pred_masks.to('cpu').numpy()
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imwrite('./out.png', v.get_image()[:, :, ::-1])

mask = processed_outputs[0]['instances'].pred_masks.to('cpu').numpy()
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
processed_outputs[0]['instances'].pred_boxes.tensor = processed_outputs[0]['instances'].pred_boxes.tensor.detach()
processed_outputs[0]['instances'].scores = processed_outputs[0]['instances'].scores.detach()
pdb.set_trace()
v = v.draw_instance_predictions(processed_outputs[0]["instances"].to("cpu"))
cv2.imwrite('./out2.png', v.get_image()[:, :, ::-1])

So once we get processed_outputs, the mask can be fetched by:

mask = processed_outputs[0]['instances'].pred_masks.to('cpu').numpy()

The masks will be integers from 0 to 255, which is good enough. Here is the image produced from the soft (non-binary) masks.
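If you need probabilities in [0, 1] rather than 0-255 values, the fetched array can simply be rescaled, e.g.:

# Convert the 0-255 soft masks back to per-pixel probabilities in [0, 1]
soft_masks = mask.astype(np.float32) / 255.0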

I encountered the same problem and tried to resolve it with @dangmanhtruong1995's approach, but I still get a binary result. I changed the threshold and checked the output at the end of paste_masks_in_image(); there it seems to be non-binary, yet the predictor still gives a binary output. Since the detectron2 architecture is quite complex, I am lost and open to suggestions.