Hello everybody. I have a problem where a patient can have multiple images from a single study, and to determine whether the patient has a disease, we average the predictions from all of their individual images. To summarize:
- Iterate over all the patient’s images.
- Predict whether the disease is present on each image.
- Take the mean of all predictions as the final prediction.
- Compute the loss by comparing the mean of the predictions with the ground truth.
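The four steps above can be sketched in plain Python without PyTorch. This is a minimal illustration, assuming the criterion is binary cross-entropy on probabilities (the post does not say which `criterion` is used). The logits are made-up values, and `patient_probability` and `bce` are hypothetical helper names.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    """Logistic sigmoid: maps a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def patient_probability(logits: Sequence[float]) -> float:
    """Steps 1-3: per-image probabilities, then their mean."""
    probs = [sigmoid(z) for z in logits]
    return sum(probs) / len(probs)


def bce(p: float, y: float, eps: float = 1e-12) -> float:
    """Step 4 (assumed criterion): binary cross-entropy of mean probability p vs label y."""
    p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


# Hypothetical patient with three images; the logits are invented for illustration
logits = [2.0, -1.0, 0.5]
p_patient = patient_probability(logits)
loss = bce(p_patient, 1.0)  # label 1.0 = disease present
```

Averaging probabilities (rather than logits) matches the order of operations in the post: sigmoid first, then mean.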
The code looks like this:
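```python
import torch

# Get predictions in the form of logits (one per image of this patient)
out = model(images)
# Sigmoid predictions to get probabilities from logits
pred = torch.sigmoid(out)
# Average all predictions for a single patient
pred = torch.mean(pred)
# criterion must accept probabilities (e.g. nn.BCELoss), and label must be
# a float tensor with the same shape as pred (a 0-dim scalar here)
loss = criterion(pred, label)
loss.backward()
```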
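Written out with the chain rule, the gradient of the loss with respect to each image's logit passes through both the mean and the sigmoid. Assuming $N$ images with logits $z_i$, $p_i = \sigma(z_i)$, $\bar p = \frac{1}{N}\sum_i p_i$, a label $y \in \{0, 1\}$, and binary cross-entropy as the criterion (an assumption; the post does not name it):

$$
\begin{aligned}
L &= -\bigl[\,y \log \bar p + (1 - y)\log(1 - \bar p)\,\bigr] \\
\frac{\partial L}{\partial \bar p} &= \frac{\bar p - y}{\bar p\,(1 - \bar p)}, \qquad
\frac{\partial \bar p}{\partial p_i} = \frac{1}{N}, \qquad
\frac{\partial p_i}{\partial z_i} = p_i\,(1 - p_i) \\
\frac{\partial L}{\partial z_i} &= \frac{\bar p - y}{\bar p\,(1 - \bar p)} \cdot \frac{1}{N} \cdot p_i\,(1 - p_i)
\end{aligned}
$$

For finite logits every factor is nonzero (since $0 < \bar p < 1$ and $y \in \{0, 1\}$), so each image receives a share of the gradient, scaled by $1/N$.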
I am not sure whether backpropagation is computed through the mean of the outputs (`pred` above), since I have not set the `requires_grad` parameter on `torch.mean()` or on `torch.sigmoid()`.
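For reference, neither `torch.sigmoid()` nor `torch.mean()` takes a `requires_grad` argument: in PyTorch, the result of an operation requires grad whenever any of its inputs does, so `pred` inherits it from `out`, whose graph reaches the model's parameters. Inspecting `pred.requires_grad` and `pred.grad_fn` shows this directly. The framework-free sketch below illustrates the same math autograd performs; it does not call PyTorch. It compares a hand-derived chain-rule gradient (assuming binary cross-entropy as the criterion) with central finite differences, and `patient_loss`, `analytic_grad`, and `numeric_grad` are hypothetical names.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def patient_loss(logits: Sequence[float], y: float) -> float:
    """Assumed BCE of the mean per-image probability against label y."""
    p_bar = sum(sigmoid(z) for z in logits) / len(logits)
    return -(y * math.log(p_bar) + (1.0 - y) * math.log(1.0 - p_bar))


def analytic_grad(logits: Sequence[float], y: float) -> List[float]:
    """dL/dz_i by the chain rule through BCE, the mean, and the sigmoid."""
    n = len(logits)
    probs = [sigmoid(z) for z in logits]
    p_bar = sum(probs) / n
    dl_dpbar = (p_bar - y) / (p_bar * (1.0 - p_bar))
    return [dl_dpbar * (1.0 / n) * p * (1.0 - p) for p in probs]


def numeric_grad(logits: Sequence[float], y: float, h: float = 1e-6) -> List[float]:
    """Central finite differences, perturbing one logit at a time."""
    grads = []
    for i in range(len(logits)):
        plus = list(logits)
        plus[i] += h
        minus = list(logits)
        minus[i] -= h
        grads.append((patient_loss(plus, y) - patient_loss(minus, y)) / (2.0 * h))
    return grads


logits = [2.0, -1.0, 0.5]  # hypothetical per-image logits for one patient
g_analytic = analytic_grad(logits, 1.0)
g_numeric = numeric_grad(logits, 1.0)
```

The two gradients agree to within finite-difference error, and every entry is nonzero, so the mean does pass gradient back to each image's logit.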
Thanks in advance.