Greetings! Looking to get some help with the following task.

I’m training a model that reconstructs occluded parts of an original image. The correct reconstruction can be multimodal in pixel-color space, and the different correct answers can be very far apart. For example, suppose the occluded region contains a black-and-white yin-yang symbol, and in the data it has white on the left and black on the right half the time, and black on the left and white on the right the other half. (A single prediction trained with plain MSE would tend to average the two modes into a uniform grey.) I’d like the model to make N different predictions and to use as the loss the minimum MSE between a prediction and the target, with the minimum taken over the N predictions. That way the model can say something like “I think the missing piece is a yin-yang symbol with either black on the left and white on the right, or vice versa” and be rewarded according to its most accurate prediction.
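A minimal sketch of the loss described above, assuming the model outputs its N candidates stacked into a tensor of shape (B, N, C, H, W) and the target has shape (B, C, H, W). The function name `min_of_n_mse` is hypothetical, and the MSE here is taken over the whole image; restricting it to the occluded pixels would need a mask, which is left out. This kind of setup is often called a winner-takes-all or multiple-choice loss.

```python
import torch

def min_of_n_mse(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Minimum-over-N MSE (hypothetical helper).

    preds:  (B, N, C, H, W), N candidate reconstructions per sample
    target: (B, C, H, W), the ground-truth image
    """
    # Broadcast the target against every candidate, then average the squared
    # error over channel and pixel dims -> (B, N)
    per_candidate = ((preds - target.unsqueeze(1)) ** 2).flatten(2).mean(dim=2)
    # Keep only the best candidate for each sample -> (B,)
    best, _ = per_candidate.min(dim=1)
    # Average over the batch -> scalar loss
    return best.mean()
```

Because `min(dim=1)` also returns the winning indices, logging them per sample is a cheap way to check whether all N candidates are actually being used.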

Is this kind of minimum loss function possible in PyTorch? If so, has anyone used it before and can share best practices? I imagine there could be issues because the loss is kinked wherever two predictions tie for the minimum, so its first derivative isn’t obviously well defined there for autograd.
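On the differentiability concern: PyTorch handles `min` much like `max`, ReLU or max-pooling, so autograd simply sends the gradient to the selected element. The sketch below checks that on a small tensor, and also shows a smooth alternative (a log-sum-exp "soft minimum") in case the hard minimum trains poorly. `soft_min_of_n_mse` and its temperature `tau` are hypothetical names, and the soft minimum is a swapped-in relaxation, not part of the original question.

```python
import torch

def soft_min_of_n_mse(per_candidate: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    """Smooth stand-in for per_candidate.min(dim=1) (hypothetical helper).

    per_candidate: (B, N) MSE of each candidate for each sample.
    -tau * logsumexp(-x / tau) lies between min(x) - tau * log(N) and min(x),
    so it approaches the hard minimum as tau -> 0.
    """
    return (-tau * torch.logsumexp(-per_candidate / tau, dim=1)).mean()

# Hard minimum: the gradient reaches only the winning candidate.
hard = torch.tensor([[0.5, 0.2, 0.9]], requires_grad=True)
hard.min(dim=1).values.mean().backward()
print(hard.grad)  # tensor([[0., 1., 0.]])

# Soft minimum: every candidate gets a softmax-weighted share of the gradient,
# with the best-scoring candidate receiving the largest share.
soft = torch.tensor([[0.5, 0.2, 0.9]], requires_grad=True)
soft_min_of_n_mse(soft, tau=0.1).backward()
print(soft.grad)
```

With the hard minimum, a candidate that starts far from every target may never win and so never receive gradient; a soft minimum (possibly with `tau` annealed toward 0 during training) is one way to keep all N candidates learning. Whether that is needed for this task is something to check empirically.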

Thank you!