Hello everybody. I have a problem in which a patient can have multiple images from a single study. To decide whether the patient has a disease, we average the predictions from all of that patient's individual images. To summarize:
- Iterate over all the patient’s images.
- Predict whether the disease is present on each image.
- Take the mean of all predictions as the final prediction.
- Compute the loss by comparing the mean prediction with the ground truth.
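The four steps above can be sketched end to end as follows. This is a minimal, self-contained sketch, not my actual setup: the tiny `nn.Sequential` model, the 8x8 image size, the `patient_loss` helper and the random data are all hypothetical stand-ins. It assumes the criterion is `nn.BCELoss`, because the averaged value is a probability rather than a logit.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model: maps one 1x8x8 image to a single logit.
model = nn.Sequential(nn.Flatten(), nn.Linear(64, 1))
# BCELoss expects probabilities, matching "sigmoid first, then average".
criterion = nn.BCELoss()


def patient_loss(images: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper. images: (num_images, 1, 8, 8) for ONE patient;
    label: 0-d float tensor (0.0 or 1.0)."""
    logits = model(images).squeeze(1)  # (num_images,): one logit per image
    probs = torch.sigmoid(logits)      # per-image probabilities
    patient_prob = probs.mean()        # 0-d tensor: patient-level prediction
    return criterion(patient_prob, label)


# One hypothetical patient with 5 random images and a positive label.
images = torch.randn(5, 1, 8, 8)
label = torch.tensor(1.0)

loss = patient_loss(images, label)
loss.backward()  # gradients reach the Linear layer's weight and bias
print(loss.item())
```

Because every image of a patient goes through the model in one batch, the mean is taken over that batch dimension, and the loss is computed once per patient.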
The code looks like this:
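```python
import torch

# Get predictions in the form of logits
out = model(images)
# Sigmoid predictions to get probabilities from logits
pred = torch.sigmoid(out)
# Average all predictions for a single patient
pred = torch.mean(pred)
# criterion must accept probabilities (e.g. nn.BCELoss, not nn.BCEWithLogitsLoss),
# and label must have the same shape as pred (a 0-d float tensor here)
loss = criterion(pred, label)
loss.backward()
```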
I do not know whether backpropagation is being computed through the mean of the outputs (called `pred` above), since I have not used the `requires_grad` parameter on `torch.mean()` nor on
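One way to check this empirically is a small sketch like the following. It assumes a hypothetical `nn.Linear(4, 1)` model, three random 4-feature "images" and a positive label, and compares the gradients autograd produces with gradients derived by hand for BCE on the mean of the sigmoids. The idea it relies on is that `nn.Parameter` tensors have `requires_grad=True` by default, so every tensor computed from them (`out`, the sigmoid, the mean) is tracked without setting any flag.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model; its weight and bias are nn.Parameter objects,
# which have requires_grad=True by default.
model = nn.Linear(4, 1)
images = torch.randn(3, 4)   # hypothetical: 3 images, 4 features each
label = torch.tensor(1.0)    # hypothetical positive label

out = model(images)               # logits, shape (3, 1)
pred = torch.sigmoid(out).mean()  # patient-level probability, 0-d tensor

# Intermediate results are tracked automatically; no requires_grad flag needed.
print(out.requires_grad, pred.requires_grad)
print(type(pred.grad_fn).__name__)

loss = nn.BCELoss()(pred, label)
loss.backward()

# Hand-derived gradient: with label 1, L = -log(p) and p = mean_i sigmoid(z_i),
# so dL/dz_i = -(1/p) * (1/n) * s_i * (1 - s_i), where s_i = sigmoid(z_i).
with torch.no_grad():
    s = torch.sigmoid(out).squeeze(1)
    p = s.mean()
    dz = -(1.0 / p) * s * (1 - s) / s.numel()
    # dL/dW = sum_i dz_i * x_i and dL/db = sum_i dz_i
    manual_w = (dz.unsqueeze(1) * images).sum(0, keepdim=True)
    manual_b = dz.sum().reshape(1)

print(torch.allclose(model.weight.grad, manual_w, atol=1e-6))
print(torch.allclose(model.bias.grad, manual_b, atol=1e-6))
```

If both comparisons print `True`, the gradient of the averaged prediction is flowing back into the model parameters as expected.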
Thanks in advance.