Hello,
My neural network detects animals and is trained on a custom dataset in COCO format.
It works well as far as drawing bounding boxes goes, but the model's output['scores'] is always 1.
Does anyone know why this happens, or how to fix it?
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    ``[(img1, tgt1), (img2, tgt2)]`` becomes ``((img1, img2), (tgt1, tgt2))``,
    so images and targets stay as separate sequences instead of being stacked.
    An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)
def get_transform():
    """Build the image transform pipeline used for training.

    Returns:
        A ``torchvision.transforms.Compose`` whose single step is
        ``torchvision.transforms.ToTensor()``.
    """
    to_tensor = torchvision.transforms.ToTensor()
    return torchvision.transforms.Compose([to_tensor])
def get_model_instance_segmentation(num_classes, pretrained=False):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    Despite the function name, this is an object-detection model (boxes,
    labels, scores), not an instance-segmentation model.

    Args:
        num_classes: Number of output classes for the new box predictor
            (torchvision's convention counts the background as a class).
        pretrained: If True, start from torchvision's COCO-pretrained weights
            before replacing the head. The default ``False`` keeps the previous
            behaviour of starting from randomly initialised weights. Training
            from scratch on a small dataset tends to overfit, so True is usually
            the better choice for fine-tuning.

    Returns:
        The detection model with its ``box_predictor`` replaced.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)
    # The replacement head must accept the same feature width as the old one.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
def _train_one_epoch(model, data_loader, optimizer, device):
    """Run one optimisation pass over *data_loader*; return per-step losses.

    Batches whose forward/backward pass raises AssertionError, ValueError or
    RuntimeError are reported and skipped rather than aborting the epoch.
    These are assumed to be the error types torchvision's detection models
    raise for malformed targets (e.g. boxes not shaped [N, 4]) — confirm
    against the torchvision version in use.

    Returns:
        A list of loss values (Python floats) for the batches that succeeded.
    """
    model.train()
    step_losses = []
    num_batches = len(data_loader)
    for step, (imgs, annotations) in enumerate(data_loader, start=1):
        imgs = [img.to(device) for img in imgs]
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        # Clear gradients before the step so that a batch which fails part-way
        # through backward() cannot leak stale gradients into the next update.
        optimizer.zero_grad()
        try:
            loss_dict = model(imgs, annotations)
            losses = sum(loss for loss in loss_dict.values())
            losses.backward()
            optimizer.step()
        except (AssertionError, ValueError, RuntimeError) as err:
            # NOTE(review): the original author suspected "zero tensors",
            # presumably images without annotations. Check what dsetCoco
            # returns for those.
            print(f"Iteration: {step}/{num_batches}, skipped: {err}")
            continue
        # .item() stores a plain float so the autograd graph is not kept alive.
        loss_value = losses.item()
        print(f"Iteration: {step}/{num_batches}, Loss: {loss_value:.4f}")
        step_losses.append(loss_value)
    return step_losses


def runTraining(epochs=None):
    """Train the 18-class detector on the COCO-format dataset and save it.

    If ``./model.pth`` exists, its weights are loaded first, so training
    resumes from the previous run. The final weights are written to
    ``model.pth``.

    Args:
        epochs: Number of passes over the training set. When ``None`` (the
            default), the user is asked interactively.

    Returns:
        One list per epoch, each holding that epoch's per-iteration losses.
    """
    print("creating dataset")
    if epochs is None:
        epochs = int(input("How many train epochs? "))
    train_data = dsetCoco(
        root="/media/john/Volume/pythonProjects/linux/wildcamai/testdata/images",
        annFile="/media/john/Volume/pythonProjects/linux/wildcamai/testdata/annotations/instances_default.json",
        transforms=get_transform(),
    )
    train_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
    )
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)
    num_classes = 18
    model = get_model_instance_segmentation(num_classes)
    if os.path.exists("./model.pth"):
        print("Previous trained model found!")
        # map_location allows weights saved on a GPU to load on a CPU-only machine.
        model.load_state_dict(torch.load("./model.pth", map_location=device))
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    total_losses = []
    for _ in range(epochs):
        total_losses.append(_train_one_epoch(model, train_data_loader, optimizer, device))
    torch.save(model.state_dict(), "model.pth")
    return total_losses
def _load_prediction_model(device):
    """Ask which detector to use and return it on *device* in eval mode.

    Answering "n" loads the custom 18-class weights from ``./model.pth``.
    Any other answer uses torchvision's COCO-pretrained Faster R-CNN; the
    prompt itself warns that the label mapping is incorrect for that model.

    Returns:
        The model, or ``None`` when custom weights were requested but
        ``./model.pth`` does not exist.
    """
    answer = input("Should out of the Box trained Model be used? (y/n) (Mapping incorrect)")
    if answer == "n":
        if not os.path.exists("./model.pth"):
            print("No pretrained Model found. Either train or use out of the box!")
            return None
        print("Previous trained model found!")
        model = get_model_instance_segmentation(18)
        # map_location allows weights saved on a GPU to load on a CPU-only machine.
        model.load_state_dict(torch.load("./model.pth", map_location=device))
    else:
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    return model.to(device).eval()


def _draw_detections(img, prediction, score_threshold):
    """Show *img* with every predicted box whose score exceeds *score_threshold*.

    *prediction* is one element of the model's output list, already moved to
    the CPU: a dict with ``boxes``, ``labels`` and ``scores`` tensors.
    """
    plt.imshow(img.cpu().permute(1, 2, 0))
    for box, label, score in zip(prediction['boxes'], prediction['labels'], prediction['scores']):
        score = score.item()
        if score <= score_threshold:
            continue
        print(box)
        x1, y1, x2, y2 = box.tolist()
        plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1])
        label_id = label.item()
        # One decimal place of percent, so a near-certain score such as 0.997
        # is not displayed as a flat 100%.
        caption = f"Id:{label_id} / {Label.getlabel(label_id)} {score:.1%}"
        plt.text(x1, y1, caption, color="RED")
    plt.show()


def runPrediction(score_threshold=0.25):
    """Run the chosen detector over the evaluation images and plot detections.

    Args:
        score_threshold: Only boxes with a score strictly above this value are
            drawn. The default matches the previously hard-coded 0.25.
    """
    # NOTE(review): runTraining reads .../testdata while this reads
    # .../traindata. Confirm the folders are not swapped; scores near 1.0 are
    # expected on images the model was trained on.
    test_data = datasets.coco.CocoDetection(
        root="/media/john/Volume/pythonProjects/linux/wildcamai/traindata/images",
        annFile="/media/john/Volume/pythonProjects/linux/wildcamai/traindata/annotations/instances_default.json",
        transform=ToTensor(),
    )
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = _load_prediction_model(device)
    if model is None:
        return
    with torch.no_grad():
        for img, _target in test_data:
            # Detection models take a list of 3-D image tensors.
            output = model([img.to(device)])
            prediction = {key: value.cpu() for key, value in output[0].items()}
            _draw_detections(img, prediction, score_threshold)