Hi, I put together a Faster R-CNN trainer that loads a directory of my own images along with accompanying XML files containing the labels and bounding boxes for each image. I used this page: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html and other resources to figure out how to do it.
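The annotation format isn't shown in this post, so the following is only a sketch that assumes Pascal VOC-style XML (the layout tools such as labelImg write). parse_voc_xml and class_to_idx are hypothetical names. The function turns one annotation into the boxes and labels lists that torchvision's detection models expect as targets.

```python
import xml.etree.ElementTree as ET


def parse_voc_xml(xml_text, class_to_idx):
    """Parse one Pascal VOC-style annotation string into (boxes, labels).

    Assumed layout (hypothetical for this dataset):
    <annotation><object><name>..</name>
      <bndbox><xmin>..</xmin><ymin>..</ymin><xmax>..</xmax><ymax>..</ymax></bndbox>
    </object>...</annotation>
    class_to_idx maps class names to ints starting at 1 (0 is background).
    """
    root = ET.fromstring(xml_text)
    boxes, labels = [], []
    for obj in root.iter("object"):
        name = obj.findtext("name").strip()
        bb = obj.find("bndbox")
        # torchvision detection models expect [x_min, y_min, x_max, y_max]
        box = [float(bb.findtext(tag)) for tag in ("xmin", "ymin", "xmax", "ymax")]
        # skip boxes with zero or negative width/height; recent torchvision
        # versions reject such boxes during training
        if box[2] <= box[0] or box[3] <= box[1]:
            continue
        boxes.append(box)
        labels.append(class_to_idx[name])
    return boxes, labels
```

For a file on disk, open(path).read() can be passed in. Inside a Dataset's __getitem__, the results would typically become torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) and torch.as_tensor(labels, dtype=torch.int64); the reshape keeps images with no objects at shape (0, 4).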
It appears to be working, i.e. it runs and seems to fine-tune the pretrained model loaded with torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True), but my results are still pretty poor, even with over 800 training images. It detects my objects correctly in some images, but in others that seem like they should be easy it is completely lost. What I'm unsure about is whether I'm running the training and evaluation for each epoch correctly.
For example, with the simpler image classification models, you do something like the train_model() function from this tutorial page: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html. In that function there are these lines:
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
along with tracking of the running loss, the running count of correct predictions, and the epoch loss and accuracy. Whenever an epoch's validation accuracy beats the best so far, the weights are copied out of the model with best_model_wts = copy.deepcopy(model.state_dict())
and those best weights are loaded back into the model before it is returned. This way the function returns the model from the best epoch rather than the one from the last epoch.
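The best-weights pattern from train_model() does not depend on the model type; only the metric changes. A minimal sketch, using a hypothetical BestModelTracker class, that works with whatever validation number is chosen (a loss, where lower is better, or something like precision or mAP, where higher is better):

```python
import copy


class BestModelTracker:
    """Keep a deep copy of the weights from the best epoch seen so far.

    Mirrors the copy.deepcopy(model.state_dict()) pattern from the transfer
    learning tutorial, but accepts any scalar metric. Use
    higher_is_better=False for a loss, True for accuracy-like metrics.
    """

    def __init__(self, higher_is_better=True):
        self.higher_is_better = higher_is_better
        self.best_metric = None
        self.best_state = None
        self.best_epoch = None

    def update(self, metric, state_dict, epoch):
        """Store state_dict if metric beats the best so far; return True if it did."""
        if self.best_metric is None:
            improved = True
        elif self.higher_is_better:
            improved = metric > self.best_metric
        else:
            improved = metric < self.best_metric
        if improved:
            self.best_metric = metric
            # deepcopy so later optimizer steps don't mutate the saved weights
            self.best_state = copy.deepcopy(state_dict)
            self.best_epoch = epoch
        return improved
```

In the epoch loop this would be tracker.update(val_metric, model.state_dict(), epoch) after evaluation, followed by model.load_state_dict(tracker.best_state) once training ends. The deepcopy matters because the tensors returned by state_dict() share memory with the live parameters, which the optimizer keeps updating.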
However, with Faster R-CNN I can't figure out how to use criterion() at all, where torch.max() would fit, how to count running correct predictions, or how to compute epoch losses. Here is my train_one_epoch() method, cobbled together from various tutorials, Stack Overflow questions, and GitHub projects:
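Object detection has no direct equivalent of torch.max(outputs, 1) followed by a correct-count, but a common stand-in is to match predicted boxes to ground-truth boxes by intersection over union (IoU) and count true positives, false positives and missed objects. This is one simple greedy version of that idea, not a torchvision API; box_iou, count_matches and the 0.5 thresholds are illustrative choices. Standard metrics such as COCO mAP (pycocotools' COCOeval) are more thorough.

```python
def box_iou(a, b):
    """IoU of two boxes given as [x_min, y_min, x_max, y_max]."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def count_matches(pred_boxes, pred_labels, pred_scores, gt_boxes, gt_labels,
                  iou_thresh=0.5, score_thresh=0.5):
    """Greedily match predictions to ground truth for one image; return (tp, fp, fn).

    A prediction is a true positive if it shares the label of a not-yet-matched
    ground-truth box and their IoU is >= iou_thresh. Predictions are visited
    from highest to lowest score, so each ground-truth box is claimed at most once.
    """
    order = sorted(range(len(pred_scores)), key=lambda i: pred_scores[i], reverse=True)
    matched = set()
    tp = fp = 0
    for i in order:
        if pred_scores[i] < score_thresh:
            continue  # ignore low-confidence detections entirely
        best_j, best_iou = None, iou_thresh
        for j, gt in enumerate(gt_boxes):
            if j in matched or gt_labels[j] != pred_labels[i]:
                continue
            iou = box_iou(pred_boxes[i], gt)
            if iou >= best_iou:
                best_j, best_iou = j, iou
        if best_j is None:
            fp += 1  # no matching object: a false detection
        else:
            matched.add(best_j)
            tp += 1
    fn = len(gt_boxes) - len(matched)  # objects nobody detected
    return tp, fp, fn
```

The tensors from an eval-mode output and from the matching target can be converted with .tolist() before calling it (torchvision.ops.box_iou computes the same pairwise IoU directly on tensors). Summing tp, fp and fn over the validation set gives precision tp / (tp + fp) and recall tp / (tp + fn), either of which could serve as a per-epoch metric for picking the best weights.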
import math
import sys

import torch

# device, number_of_epochs and warmup_lr_scheduler are globals defined elsewhere in my script
def train_one_epoch(model, criterion, optimizer, dataloader, epoch):
    model.train()
    print("Epoch: " + str(epoch) + "/" + str(number_of_epochs))
    lr_scheduler = None
    len_dataloader = len(dataloader)
    if epoch == 0:
        # linearly warm up the learning rate during the first epoch
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len_dataloader - 1)
        lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
    i = 0
    with torch.set_grad_enabled(True):
        for images, targets in dataloader:
            i += 1
            optimizer.zero_grad()
            images = list(img.to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # in train mode the model returns a dict of losses
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            # if loss is infinite we got problems
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                print(loss_dict)
                sys.exit(1)
            losses.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            if i % 6 == 0:
                print(f'Iteration: {i}/{len_dataloader}, Loss: {loss_value}')
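For epoch losses in a loop like the one above, one option is to average each component of loss_dict over the iterations of an epoch. In train mode, torchvision's Faster R-CNN returns loss_classifier, loss_box_reg, loss_objectness and loss_rpn_box_reg, and watching them separately shows whether classification or box regression is lagging. A minimal sketch with a hypothetical EpochLossMeter class:

```python
from collections import defaultdict


class EpochLossMeter:
    """Accumulate per-component losses into per-iteration epoch averages.

    Expects plain floats, e.g. {k: v.item() for k, v in loss_dict.items()},
    and also tracks the summed "total" loss.
    """

    def __init__(self):
        self.sums = defaultdict(float)
        self.count = 0

    def update(self, loss_values):
        for name, value in loss_values.items():
            self.sums[name] += value
        self.sums["total"] += sum(loss_values.values())
        self.count += 1

    def averages(self):
        """Mean of each component over all update() calls so far."""
        if self.count == 0:
            return {}
        return {name: total / self.count for name, total in self.sums.items()}
```

Inside train_one_epoch, the meter would be created before the loop, updated with meter.update({k: v.item() for k, v in loss_dict.items()}) next to loss_value = losses.item(), and meter.averages() printed after the loop. Passing floats rather than tensors avoids keeping the autograd graph alive. These are still training losses, so on their own they don't show how well the model does on held-out images.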
And here is my evaluate() method:
def evaluate(model, dataloader):
    model.eval()
    running_loss = 0.0
    for images, targets in dataloader:
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        with torch.no_grad():
            outputs = model(images)
        # in eval mode the model returns a list with one dict per image
        # (its length is the batch size), and each dict holds:
        #   outputs[index]['boxes']
        #   outputs[index]['labels']
        #   outputs[index]['scores']
        # note: this adds up confidence scores, so it is not a loss in the training sense
        for output in outputs:
            running_loss += output['scores'].sum().item()
    return running_loss
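In eval mode, torchvision's Faster R-CNN returns every detection whose score exceeds box_score_thresh (0.05 by default), up to box_detections_per_img (100 by default) per image, so summing all of the scores as evaluate() does mixes confident boxes with near-random ones. One option is to drop low-confidence detections before counting or scoring them. A minimal sketch, with a hypothetical filter_detections helper and an illustrative 0.5 threshold, written on plain lists:

```python
def filter_detections(output, score_thresh=0.5):
    """Keep only detections whose confidence is at least score_thresh.

    `output` is one element of the eval-mode list, with "boxes", "labels"
    and "scores" as plain Python lists (call .tolist() on the tensors first).
    """
    keep = [i for i, s in enumerate(output["scores"]) if s >= score_thresh]
    return {key: [output[key][i] for i in keep] for key in ("boxes", "labels", "scores")}
```

With the tensors left as they are, the same filter is keep = output['scores'] >= 0.5 followed by output['boxes'][keep], output['labels'][keep] and output['scores'][keep].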
As you can see, all evaluate() does is put the model into eval() mode and run it on a set of evaluation images that is smaller than the training set. Nothing is done with the running_loss it returns except printing it to the console. For training to be truly effective, does criterion() need to be called during training, and do I need to track the best model weights along with running missed predictions and losses? Or are my results poor simply because my training set isn't good enough and I need better images of my objects? Thanks for any help!