Ensembling Predictions for a Combined Uncertainty-Based Loss and Dice

I’m trying to train multiple models in parallel, combine their predictions to produce an uncertainty estimate, and then measure the quality of that estimate with Expected Calibration Error (ECE).

For each model, I then want the loss to be the Dice loss of its own prediction plus the ECE of the ensembled models’ predictions.
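Written out, the per-model objective is roughly the following. This assumes K models with parameters θ_k and softmax outputs p_k, ground truth y, an ensemble prediction equal to the mean of the members’ probabilities, and the usual equal-width binned ECE over M bins, where n is the number of samples (or pixels) and B_m is the set whose ensemble confidence falls in bin m. The mean and the binned form are assumptions; the post doesn’t say which variant calc_ece uses.

$$
\begin{aligned}
\mathcal{L}_k &= \mathrm{Dice}(p_k, y) + \mathrm{ECE}(\bar{p}, y), \qquad \bar{p} = \frac{1}{K}\sum_{j=1}^{K} p_j, \\
\mathrm{ECE}(\bar{p}, y) &= \sum_{m=1}^{M} \frac{|B_m|}{n}\,\bigl|\mathrm{acc}(B_m) - \mathrm{conf}(B_m)\bigr|, \\
\frac{\partial \mathcal{L}_k}{\partial \theta_k} &= \frac{\partial\,\mathrm{Dice}(p_k, y)}{\partial \theta_k} + \frac{1}{K}\,\frac{\partial\,\mathrm{ECE}}{\partial \bar{p}}\,\frac{\partial p_k}{\partial \theta_k}.
\end{aligned}
$$

Because the ECE term depends on every p_j through the mean, it contributes a gradient to every model, scaled by 1/K.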

I’m not sure how backprop will work with this. I understand that you can simply add losses together and then call .backward(). However, if I combine multiple predictions and compute the ECE on the combined result, is it fine to add that same ECE value to each individual model’s Dice loss?
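A term computed from all models’ outputs has a gradient path back into each of them, so it can be added to the per-model terms and handled by a single backward(). Here is a minimal sketch. The two nn.Linear layers are hypothetical stand-ins for the models, cross-entropy stands in for Dice, and the entropy of the ensemble mean stands in for the ECE.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for the ensemble members and data (hypothetical shapes).
models = [nn.Linear(4, 3) for _ in range(2)]
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

logits = [m(x) for m in models]
preds = [z.softmax(dim=1) for z in logits]

# Shared term built from ALL predictions: mean entropy of the ensemble
# average, standing in for the ECE.
ensemble = torch.stack(preds).mean(dim=0)
shared = -(ensemble * ensemble.log()).sum(dim=1).mean()

# The shared term alone already has a gradient w.r.t. every model's weights.
# retain_graph=True because total.backward() below reuses this graph.
shared_grads = torch.autograd.grad(
    shared, [m.weight for m in models], retain_graph=True
)

# Per-model terms plus the shared term, summed into one scalar.
total = sum(F.cross_entropy(z, y) for z in logits) + shared
total.backward()  # one backward pass fills .grad for every model
```

Nothing special is needed beyond keeping everything as torch ops: the same shared tensor can appear in several terms of one sum.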

To put it in pseudocode:

for epoch in epochs:
  preds = predict_all_models(models)
  stacked_pred = stack_preds(preds)
  uncertainty = calc_entropy_uncertainty(stacked_pred)
  ece = calc_ece(stacked_pred, uncertainty, truth)
  for model in models:
    dice_loss = dice(preds[model], truth)   # per-model Dice against the ground truth
    combined_loss = dice_loss + ece
    combined_loss.backward()
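For reference, the two uncertainty helpers in the pseudocode could look something like this. `stack_preds` and `calc_entropy_uncertainty` are the post’s own placeholder names; the bodies are an assumption (predictive entropy of the ensemble mean, with probabilities shaped (N, C, ...) and the class dimension at index 1).

```python
import torch

def stack_preds(preds):
    # torch.stack keeps every member's prediction in the autograd graph.
    return torch.stack(preds, dim=0)  # (K, N, C, ...)

def calc_entropy_uncertainty(stacked_pred, eps=1e-8):
    # Predictive entropy of the ensemble mean, one value per sample/pixel.
    mean_pred = stacked_pred.mean(dim=0)                       # (N, C, ...)
    return -(mean_pred * (mean_pred + eps).log()).sum(dim=1)   # (N, ...)
```

Two members that agree on a uniform prediction over C classes give an entropy of about log C, while two confident members that disagree on a 3-class sample ([1, 0, 0] vs [0, 1, 0]) give about log 2.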

Also, I’m not sure how autograd works: is there anything special I need to do to make backward() work, or can I just calculate the ECE using PyTorch functions and it’ll magically work?

If ECE is implemented with differentiable torch ops, then yes, the gradient should propagate through it. Just avoid anything that breaks the graph, such as building new tensors from existing values (e.g. with torch.tensor(...)), converting to NumPy and back, or calling .item() on intermediate values.
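Here is a minimal sketch of a binned ECE written only with torch ops, assuming softmax probabilities of shape (N, C) and integer labels. `binned_ece` and the equal-width bins are illustrative assumptions, not necessarily what calc_ece does (an uncertainty-based variant would bin by the entropy instead of the confidence). One thing worth knowing: argmax, comparisons, and bin assignment have no gradient, so only the confidence terms send gradient back to the predictions.

```python
import torch

def binned_ece(probs, labels, n_bins=10):
    """Equal-width binned ECE for probabilities of shape (N, C)."""
    conf, pred = probs.max(dim=1)            # conf keeps its grad_fn
    correct = (pred == labels).float()       # comparison: no gradient
    edges = torch.linspace(0, 1, n_bins + 1, device=probs.device)
    ece = probs.new_zeros(())
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)  # bin membership: no gradient
        if in_bin.any():
            frac = in_bin.float().mean()     # |B_m| / n
            gap = correct[in_bin].mean() - conf[in_bin].mean()
            ece = ece + frac * gap.abs()     # gradient flows via conf only
    return ece
```

For segmentation outputs of shape (N, C, H, W), the pixels can be flattened into samples first, e.g. `probs.permute(0, 2, 3, 1).reshape(-1, C)` with the labels flattened to match.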

So the gradients will propagate through all of the models, even though the predictions are stacked to produce the ECE?

If the predictions are continuous (e.g. softmax probabilities rather than thresholded masks) and you use torch.stack (rather than constructing a new tensor yourself), then yes.
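The difference between torch.stack and rebuilding or thresholding a tensor shows up in `grad_fn` / `requires_grad`. A minimal sketch with random stand-in predictions `p1` and `p2`:

```python
import torch

p1 = torch.rand(4, requires_grad=True)
p2 = torch.rand(4, requires_grad=True)

stacked = torch.stack([p1, p2])                        # recorded by autograd
rebuilt = torch.from_numpy(stacked.detach().numpy())   # graph is cut here
hard = (stacked > 0.5).float()                         # thresholding: no gradient

print(stacked.grad_fn is not None)  # True
print(rebuilt.requires_grad)        # False
print(hard.requires_grad)           # False

stacked.sum().backward()            # reaches both original predictions
print(p1.grad, p2.grad)             # tensors of ones
```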

I’m now getting a RuntimeError: Trying to backward through the graph a second time.
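A minimal reproduction, assuming the real loop calls combined_loss.backward() once per model inside the inner loop (as EDIT 2 suggests). Sharing inputs between terms is fine within one backward(); the error comes from a second backward() that walks through the shared ECE nodes whose saved buffers the first call already freed. `w1`, `w2`, and `build_losses` are hypothetical stand-ins.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for two models' parameters.
w1 = torch.randn(3, requires_grad=True)
w2 = torch.randn(3, requires_grad=True)

def build_losses():
    p1, p2 = w1.softmax(dim=0), w2.softmax(dim=0)
    mean = torch.stack([p1, p2]).mean(dim=0)
    shared = -(mean * mean.log()).sum()       # shared "ECE-like" term
    # One loss per model: its own term plus the SAME shared term.
    return [(p1 ** 2).sum() + shared, (p2 ** 2).sum() + shared]

# What the per-model loop does: one backward() per loss.
losses = build_losses()
losses[0].backward()          # frees the saved buffers of the shared subgraph
try:
    losses[1].backward()      # needs those buffers again
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", str(err).splitlines()[0])

# Fix: rebuild the graph, sum the losses, call backward() once.
w1.grad, w2.grad = None, None
total = sum(build_losses())
total.backward()
```

Passing retain_graph=True to the earlier backward() calls also avoids the error, and the accumulated .grad ends up the same as for the summed loss, but the summed version needs only one backward pass and doesn’t keep the graph alive longer than necessary.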

Am I right in thinking that, because stacked_pred and uncertainty both come from the same preds variable, calling backward on the ECE propagates gradients from both stacked_pred and uncertainty back to the original predictions?

If so, since both are combined to produce the ECE, is it okay to call .detach() on one of them to solve the problem? Or would that stop the ECE gradients from backpropagating properly?
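On the detach question: autograd sums the contributions from every path, so gradients flowing back through both stacked_pred and uncertainty is normal and isn’t what raises the error. Detaching one of them silently drops that path’s contribution. In this toy sketch, `x` is a hypothetical stand-in for the predictions and `x ** 2` for the uncertainty:

```python
import torch

x = torch.tensor([0.2, 0.8], requires_grad=True)

# Both paths kept: d/dx of x * x**2 = 3x^2.
(x * x ** 2).sum().backward()
both_paths = x.grad.clone()      # ≈ [0.12, 1.92]

x.grad = None
# "Uncertainty" detached: x**2 is treated as a constant, leaving only x^2.
(x * (x ** 2).detach()).sum().backward()
one_path = x.grad.clone()        # ≈ [0.04, 0.64]
```

So detaching would change the ECE gradient, and it wouldn’t remove the rest of the shared graph that repeated backward() calls walk through.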

EDIT: Maybe not… perhaps it’s because I’m adding the Dice and the ECE, which both come from the same prediction?

It’s probably worth mentioning that the dice in my original pseudocode is the Dice loss of an individual model’s prediction.

EDIT 2: Thinking about it a bit more, I should probably do

    combined_loss = ece
    for model in models:
        combined_loss = combined_loss + dice(preds[model], truth)
    combined_loss.backward()

otherwise I’ll be calling .backward() multiple times through the same ECE graph, once per model.
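Putting it together, a full training step along those lines might look like the sketch below. Everything here is hypothetical: three tiny Conv2d layers stand in for 2-class segmentation models, `soft_dice_loss` is a standard soft Dice, and `one_bin_ece` (ECE with a single bin) is a compact stand-in for calc_ece. The key points are that the ECE is computed once from the stacked predictions, each model’s Dice is added to it, and backward() runs once per step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def soft_dice_loss(fg_prob, target, eps=1e-6):
    # 1 - soft Dice between foreground probabilities and a binary mask.
    inter = (fg_prob * target).sum()
    return 1 - (2 * inter + eps) / (fg_prob.sum() + target.sum() + eps)

def one_bin_ece(mean_pred, truth):
    # ECE with a single bin, |accuracy - mean confidence|;
    # a compact stand-in for calc_ece.
    conf, pred = mean_pred.max(dim=1)
    acc = (pred == truth).float().mean()
    return (acc - conf.mean()).abs()

# Hypothetical tiny 2-class segmentation models and random data.
models = [nn.Conv2d(1, 2, kernel_size=3, padding=1) for _ in range(3)]
opt = torch.optim.SGD([p for m in models for p in m.parameters()], lr=0.1)
x = torch.randn(2, 1, 8, 8)
truth = (torch.rand(2, 8, 8) > 0.5).long()
fg = truth.float()

initial = [m.weight.detach().clone() for m in models]  # only to check updates

for epoch in range(2):
    opt.zero_grad()
    preds = [m(x).softmax(dim=1) for m in models]   # each (N, 2, H, W)
    mean_pred = torch.stack(preds).mean(dim=0)      # ensemble average
    ece = one_bin_ece(mean_pred, truth)             # shared by all models

    combined_loss = ece                             # shared term added once
    for p in preds:
        combined_loss = combined_loss + soft_dice_loss(p[:, 1], fg)
    combined_loss.backward()                        # single backward pass
    opt.step()
```

Adding the ECE once (as above) versus summing dice_k + ece over all K models only changes the weighting: the per-model sum is the same as adding the ECE K times, i.e. a K-times larger weight on the calibration term.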