How to catch exceptions for individual elements of a minibatch in a model's forward function

I’m trying to train a deep model using PyTorch. One of its layers is a differentiable optimization layer, which I implemented using cvxpylayers (GitHub: cvxgrp/cvxpylayers, "Differentiable convex optimization layers"). Depending on its input, the optimization problem may be infeasible (this is the intended behavior for some inputs), in which case the optimization layer raises an exception. When that happens, I would like to bypass the optimization layer and use a different layer instead, one that does not raise an exception. I have implemented this with a try/except block in the model's forward function.

My problem is that, when training with minibatches, an exception in the optimization layer for any single training example causes the optimization layer to be bypassed for all training examples in the minibatch.
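For reference, a minimal reproduction of the pattern described above might look like the following. `ToyOptLayer`, `InfeasibleError`, and the `nn.Linear` fallback are hypothetical stand-ins for the cvxpylayers layer, the exception it raises, and the fallback layer. The only point illustrated is that a batched call fails as a whole, so the `except` branch is taken for every sample in the minibatch.

```python
import torch
import torch.nn as nn


class InfeasibleError(RuntimeError):
    """Hypothetical stand-in for the exception raised on an infeasible problem."""


class ToyOptLayer(nn.Module):
    """Hypothetical stand-in for a cvxpylayers layer.

    Like a single batched solve, it raises if *any* row is infeasible
    (here: a row whose entries sum to a negative number).
    """

    def forward(self, x):
        if (x.sum(dim=-1) < 0).any():
            raise InfeasibleError("problem is infeasible")
        return torch.relu(x)


class BatchLevelModel(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.opt_layer = ToyOptLayer()
        self.fallback = nn.Linear(dim, dim)  # hypothetical fallback layer

    def forward(self, x):
        try:
            return self.opt_layer(x)
        except InfeasibleError:
            # One infeasible sample sends the whole minibatch through the fallback.
            return self.fallback(x)
```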

Is there any way to do an element-wise try/except inside the forward function of a PyTorch model that still allows batch-mode training in my situation?
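One possible approach, sketched here with hypothetical stand-ins for the optimization layer and its exception, is to try the whole batch first and, only if that raises, retry each sample separately as a batch of size 1, then join the results with `torch.cat`. This assumes the optimization layer accepts a size-1 batch and that you know which exception it raises; with cvxpylayers you would catch whatever exception your solver actually throws rather than the made-up `InfeasibleError`. The boolean mask of which samples used the fallback is an extra, not part of the original setup.

```python
import torch
import torch.nn as nn


class InfeasibleError(RuntimeError):
    """Hypothetical stand-in for the exception raised on an infeasible problem."""


class ToyOptLayer(nn.Module):
    """Hypothetical stand-in for a cvxpylayers layer: raises if any row sums to < 0."""

    def forward(self, x):
        if (x.sum(dim=-1) < 0).any():
            raise InfeasibleError("problem is infeasible")
        return torch.relu(x)


class PerSampleFallbackModel(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.opt_layer = ToyOptLayer()
        self.fallback = nn.Linear(dim, dim)  # hypothetical fallback layer

    def forward(self, x):
        batch_size = x.shape[0]
        # Fast path: solve the whole minibatch at once.
        try:
            return self.opt_layer(x), torch.zeros(batch_size, dtype=torch.bool)
        except InfeasibleError:
            pass
        # Slow path: at least one sample failed, so handle samples one at a time.
        outputs, used_fallback = [], []
        for i in range(batch_size):
            xi = x[i : i + 1]  # slice keeps the batch dimension (size 1)
            try:
                outputs.append(self.opt_layer(xi))
                used_fallback.append(False)
            except InfeasibleError:
                outputs.append(self.fallback(xi))
                used_fallback.append(True)
        return torch.cat(outputs, dim=0), torch.tensor(used_fallback)
```

The per-sample loop gives up the batched solve for any minibatch that contains an infeasible sample, so those steps are slower. Gradients still flow through whichever branch each sample took, since slicing and `torch.cat` are differentiable.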

Thanks a lot for your help!