Hello:
I’m getting started with semantic segmentation using the fcn_resnet101 model. I stream frames from my webcam and overlay the recognized classes frame by frame, but it is very slow (around 1.94 seconds per frame). I think that removing the unused classes could make it a little faster.
The fcn_resnet101 model has 21 classes, but I just need around three or four classes from that model.
Is there some way to remove the unused classes, or to just omit them while the image analysis is running?
Here is my code:
import sys
import torch
from PIL import Image
import torchvision.transforms as T
import numpy as np
from torchvision import models
import cv2
import time
# Define the helper function
def decode_segmap(image, nc=21):
    """Convert a map of class indices into an RGB colour image.

    Args:
        image: Array of integer class indices, one per pixel (for example
            the per-pixel argmax of the model output).
        nc: Number of classes to colour. Pixels whose label is outside
            ``[0, nc)`` are left black.

    Returns:
        A ``uint8`` array shaped like ``image`` with an extra trailing axis
        of size 3 holding the (R, G, B) colour of every pixel.
    """
    label_colors = np.array(
        [
            (0, 0, 0),  # 0=background
            # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
            (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
            # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
            (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
            # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
            (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
            # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
            (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
        ],
        dtype=np.uint8,
    )

    labels = np.asarray(image)
    # Clamp nc to the palette size so an oversized (or negative) value cannot
    # index outside the colour table.
    usable = max(0, min(nc, len(label_colors)))

    rgb = np.zeros(labels.shape + (3,), dtype=np.uint8)
    # One table lookup for all pixels instead of one full-image pass per class.
    known = (labels >= 0) & (labels < usable)
    rgb[known] = label_colors[labels[known].astype(np.intp)]
    return rgb
def segmentation(modelo, img, resize, center_crop, device):
    """Run the segmentation model on one image and colour-code the result.

    Args:
        modelo: Model that, when called on a batch tensor, returns a mapping
            whose ``'out'`` entry holds the per-class scores.
        img: Image accepted by the torchvision transforms below.
        resize: Size passed to ``T.Resize``.
        center_crop: Size passed to ``T.CenterCrop``.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        The result of ``decode_segmap`` applied to the per-pixel argmax of
        the model output.
    """
    trf = T.Compose([
        T.Resize(resize),
        T.CenterCrop(center_crop),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])
    # Add the batch dimension and move the input next to the model.
    inp = trf(img).unsqueeze(0).to(device)

    # Inference only: disabling autograd avoids building a graph on every
    # frame, which saves both time and memory.
    with torch.no_grad():
        out = modelo(inp)['out']

    # Drop only the batch axis (a bare squeeze() would also collapse a
    # height or width of 1), then pick the best class per pixel.
    om = torch.argmax(out.squeeze(0), dim=0).cpu().numpy()
    return decode_segmap(om)
def replace_segmentation(imagen_segmentada, frame):
    """Copy the original pixels into every non-black pixel of the mask.

    A pixel of ``imagen_segmentada`` counts as non-black when at least one of
    its channels is greater than zero; each such pixel is overwritten with
    the pixel at the same position in ``frame``. Black pixels are left as is.

    Args:
        imagen_segmentada: Colour-coded segmentation image (array with the
            channels on the last axis). Modified in place.
        frame: Original image, at least as large as ``imagen_segmentada``.

    Returns:
        ``imagen_segmentada`` itself, after the in-place update.
    """
    rows, cols = imagen_segmentada.shape[:2]
    # Select the pixels that are not black (any channel above zero).
    foreground = (imagen_segmentada > 0).any(axis=-1)
    # Single masked copy instead of a Python loop over every pixel/channel.
    imagen_segmentada[foreground] = frame[:rows, :cols][foreground]
    return imagen_segmentada
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()
fcn.to(device)

cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)
try:
    while True:
        # Capture frame by frame; stop when the camera delivers nothing.
        ret, frame = cap.read()
        if not ret:
            break

        # Per-frame processing. OpenCV delivers BGR data while the model
        # expects RGB, so convert before building the PIL image.
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        height, width = frame.shape[:2]

        start = time.time()
        imagen_segmentada = segmentation(
            fcn, img, [height, width], [height, width], device
        )
        end = time.time()
        print('segmentacion semantica')
        print(end - start)

        start = time.time()
        imagen_sobrepuesta = replace_segmentation(imagen_segmentada, frame)
        end = time.time()
        print('sobreponer frame original')
        print(end - start)

        # Show the segmentation result; press "q" to quit.
        cv2.imshow('frame', imagen_sobrepuesta)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even after an error.
    cap.release()
    cv2.destroyAllWindows()