I wrote the following class to perform instance segmentation and return the masks of a given class.

The code seems to be running randomly and it’s not deterministic.

The labels printed (as well as the number of labels) change at every execution even if I am running the code on the same input image containing a single person.

Is there a problem in how I load the weights? The code is not printing any warning nor exception.

Note that I am running the code on the CPU.

```
import numpy as np
import torch
from torch import Tensor
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
import torchvision.transforms as T
import PIL
from PIL import Image
class RetinaNet:
def __init__(self, weights: RetinaNet_ResNet50_FPN_V2_Weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1):
# Load the pre-trained DeepLabV3 model
self.weights = weights
self.model = retinanet_resnet50_fpn_v2(
pretrained=RetinaNet_ResNet50_FPN_V2_Weights
)
self.model.eval()
# Check if a GPU is available and if not, use a CPU
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model.to(self.device)
# Define the transformation
self.transform = T.Compose([
T.ToTensor(),
])
def infer_on_image(self, image: PIL.Image.Image, label: str) -> Tensor:
# Transform image
input_tensor = self.transform(image)
input_tensor = input_tensor.unsqueeze(0)
input_tensor.to(self.device)
# Run model
with torch.no_grad():
predictions = self.model(input_tensor)
# Post-processing to create masks for requested label
label_index = self.get_label_index(label)
boxes = predictions[0]['boxes'][predictions[0]['labels'] == label_index]
print('labels', predictions[0]['labels']) # random output
masks = torch.zeros((len(boxes), input_tensor.shape[1], input_tensor.shape[2]), dtype=torch.uint8)
for i, box in enumerate(boxes.cpu().numpy()):
x1, y1, x2, y2 = map(int, box)
masks[i, y1:y2, x1:x2] = 1
return masks
def get_label_index(self,label: str) -> int:
return self.weights.value.meta['categories'].index(label)
def get_label(self, label_index: int) -> str:
return self.weights.value.meta['categories'][label_index]
@staticmethod
def load_image(file_path: str) -> PIL.Image.Image:
return Image.open(file_path).convert("RGB")
if __name__ == '__main__':
from matplotlib import pyplot as plt
image_path = 'person.jpg'
# Run inference
retinanet = RetinaNet()
masks = retinanet.infer_on_image(
image=retinanet.load_image(image_path),
label='person'
)
# Plot image
plt.imshow(retinanet.load_image(image_path))
plt.show()
# PLot mask
for i, mask in enumerate(masks):
mask = mask.unsqueeze(2)
plt.title(f'mask {i}')
plt.imshow(mask)
plt.show()
```