This is my first time using PyTorch (sorry if I posted in the wrong category), and I've run into a strange issue while writing the validation phase for my model. Here's the code snippet:
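```python
import torch

total_val_loss = 0.0  # accumulated over all validation batches

with torch.no_grad():
    for images, targets in validation_dataloader:
        # Forward pass
        loss_dicts = model(images, targets)
        # Sum the values of each per-image dict, then average over the batch
        batch_losses = [sum(loss_dict.values()) for loss_dict in loss_dicts]
        total_loss = sum(batch_losses) / len(batch_losses)
        total_val_loss += total_loss
```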
`model(images, targets)` is returning a list of dictionaries, one per image in the batch. That isn't what I expected: my understanding was that calling the model with targets would return the losses. Did I misunderstand what I read, or is there any way I can fix this?
If needed, I’d be more than happy to provide any other context.