Reduce the classes of a pre-trained model (fcn_resnet101) to speed up semantic segmentation

Hello:

I’m getting started with semantic segmentation using the fcn_resnet101 model. I stream webcam frames with the recognized classes overlaid, but it is very slow (around 1.94 seconds per frame). I think removing the unused classes could make it a bit faster.

The fcn_resnet101 model has 21 classes, but I just need three or four of them.

Is there some way to remove the unused classes, or to skip them while the image analysis is running?

Here is my code:

import sys
import torch
from PIL import Image
import torchvision.transforms as T
import numpy as np
from torchvision import models
import cv2
import time


# Define the helper function
def decode_segmap(image, nc=21):
    label_colors = np.array([(0, 0, 0),  # 0=background
                             # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                             (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                             # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                             (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                             # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                             (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                             # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                             (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    for l in range(0, nc):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    rgb = np.stack([r, g, b], axis=2)
    return rgb

def segmentation(modelo, img, resize, center_crop, device):
    trf = T.Compose([T.Resize(resize),
                     T.CenterCrop(center_crop),
                     T.ToTensor(),
                     T.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])
    inp = trf(img).unsqueeze(0).to(device)

    # run inference without tracking gradients; the output is already on `device`
    with torch.no_grad():
        out = modelo(inp)['out']

    om = torch.argmax(out.squeeze(), dim=0).cpu().numpy()
    return decode_segmap(om)

def replace_segmentation(imagen_segmentada, frame):
    # replace every non-black (i.e. segmented) pixel with the original frame pixel
    for idx_, val_ in enumerate(imagen_segmentada):
        for idx, val in enumerate(val_):
            if val.any():
                imagen_segmentada[idx_][idx] = frame[idx_][idx]
    return imagen_segmentada


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()
fcn.to(device)


cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)  # request a smaller capture width


while True:
    # capture frame by frame
    ret, frame = cap.read()
    if not ret:
        break

    # OpenCV delivers BGR, but the model was trained on RGB images
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    h, w = frame.shape[:2]

    start = time.time()
    imagen_segmentada = segmentation(fcn, img, [h, w], [h, w], device)
    end = time.time()
    print('semantic segmentation')
    print(end - start)

    start = time.time()
    imagen_sobrepuesta = replace_segmentation(imagen_segmentada, frame)
    end = time.time()
    print('overlay original frame')
    print(end - start)

    # segmentation results
    cv2.imshow('frame', imagen_sobrepuesta)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
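
Incidentally, the per-pixel Python loops in replace_segmentation can be replaced by a NumPy boolean mask, which is far faster than iterating in Python. A minimal sketch doing the same overlay (assuming both arrays are uint8 with the same height and width):

def replace_segmentation(imagen_segmentada, frame):
    # True for every pixel where any channel is non-zero (i.e. not background)
    mask = imagen_segmentada.any(axis=2)
    # copy the original frame pixels into the segmented image at those positions
    imagen_segmentada[mask] = frame[mask]
    return imagen_segmentada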

If you would like to keep only e.g. the first 4 classes, you could manipulate the last conv layer:

import torch.nn as nn

with torch.no_grad():
    fcn.classifier[4].weight = nn.Parameter(fcn.classifier[4].weight[:4])
    fcn.classifier[4].bias = nn.Parameter(fcn.classifier[4].bias[:4])
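
After this, the model's output should have 4 channels instead of 21. A quick way to check (the dummy input shape is just an example):

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224).to(device)
    print(fcn(dummy)['out'].shape)  # torch.Size([1, 4, 224, 224])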

Thank you. It works, but the execution time isn’t reduced.

Is there some way to set specific classes, e.g. 0, 3, 16?

Yeah, I wouldn’t expect slicing the last conv layer to save much time, since nearly all of the compute happens in the ResNet backbone before it.
Did you time it?
If so, what was the difference?
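
Also, if you are timing on the GPU, keep in mind that CUDA calls are asynchronous, so you should synchronize before reading the clock. A minimal sketch for timing just the forward pass (inp as in your segmentation function):

if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure pending GPU work is finished
start = time.time()
with torch.no_grad():
    out = fcn(inp)['out']
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the forward pass to complete
print(time.time() - start)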

To use specific class indices, just pass them as a list:

with torch.no_grad():
    fcn.classifier[4].weight = nn.Parameter(fcn.classifier[4].weight[[0, 3, 16]])
    fcn.classifier[4].bias = nn.Parameter(fcn.classifier[4].bias[[0, 3, 16]])
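
One caveat: after slicing, torch.argmax returns channel indices 0, 1, 2 rather than the original class ids 0, 3, 16, so the colors looked up in decode_segmap will be wrong unless you map the indices back first. A minimal sketch with a hypothetical helper (kept_classes holds the original ids you kept):

kept_classes = [0, 3, 16]  # original VOC class ids kept after slicing

def decode_segmap_subset(image, kept=kept_classes):
    # remap each channel index (0..len(kept)-1) back to its original class id,
    # then reuse the existing color lookup
    remapped = np.zeros_like(image)
    for new_idx, orig_idx in enumerate(kept):
        remapped[image == new_idx] = orig_idx
    return decode_segmap(remapped)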