Is it possible to compute a model's loss with respect to only certain batch elements?
To be concrete, let's say we have some model `model` with `loss = model(x)`, where `x` has size `[batch_size, dim]`. There is a boolean mask `mask` whose entries are `True` at the batch indices where the tensor should be evaluated and `False` otherwise. Is there a way to make the model compute the gradient only from these entries?
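Here is a minimal sketch of one common approach. It assumes the loss can be computed per batch element, for example with `reduction="none"`, instead of being reduced to a scalar inside `model`. The `nn.Linear` model, the `target` tensor, and the `per_sample_loss` helper are hypothetical stand-ins for the question's `model`. Multiplying the per-element losses by the mask before reducing gives zero gradient to the rows where `mask` is `False`. The block also checks the result against computing the loss on `x[mask]` directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, dim = 8, 4
net = nn.Linear(dim, 1)              # hypothetical stand-in for the question's model
x = torch.randn(batch_size, dim)
target = torch.randn(batch_size, 1)  # hypothetical regression targets
mask = torch.tensor([True, False, True, True, False, False, True, False])

def per_sample_loss(inputs, targets):
    # reduction="none" keeps one loss per batch element; squeeze gives shape [n]
    return F.mse_loss(net(inputs), targets, reduction="none").squeeze(1)

losses = per_sample_loss(x, target)  # shape [batch_size]

# Zero the rows where mask is False, then average over the selected rows only.
masked_loss = (losses * mask.float()).sum() / mask.sum()
masked_loss.backward()
grad_masked = net.weight.grad.clone()

# Reference: loss computed on the selected rows alone, as in model(x[mask]).
net.zero_grad()
per_sample_loss(x[mask], target[mask]).mean().backward()
grad_subset = net.weight.grad.clone()

print(torch.allclose(grad_masked, grad_subset))
```

Dividing by `mask.sum()` makes the masked loss equal to the `.mean()` over the selected rows, so the two gradients agree.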
While it is possible to compute the loss on the unmasked elements only, `loss = model(x[mask])`, I would like to still make use of the loss of the masked elements of `x` (those where `mask` is `False`), just without letting these elements affect the model.
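One way to keep the loss values of the masked elements while blocking their gradient is to detach those entries with `torch.where` before reducing. This is a sketch, not the only way to do it. Detached entries keep their numeric value, but autograd treats them as constants. As before, the linear model and `target` are hypothetical stand-ins for the question's `model`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, dim = 8, 4
net = nn.Linear(dim, 1)              # hypothetical stand-in for the question's model
x = torch.randn(batch_size, dim)
target = torch.randn(batch_size, 1)  # hypothetical regression targets
mask = torch.tensor([True, False, True, True, False, False, True, False])

# One loss per batch element, shape [batch_size]
losses = F.mse_loss(net(x), target, reduction="none").squeeze(1)

# Same values as `losses`, but rows where mask is False are detached,
# so they are constants for autograd and contribute no gradient.
losses_for_grad = torch.where(mask, losses, losses.detach())
total_loss = losses_for_grad.sum()   # the value still includes every element
total_loss.backward()
grad_detach = net.weight.grad.clone()

# Loss of the masked elements, available for logging etc., with no graph attached.
excluded_loss = losses.detach()[~mask]

# Reference: summed loss over the selected rows only.
net.zero_grad()
F.mse_loss(net(x[mask]), target[mask], reduction="sum").backward()
grad_subset = net.weight.grad.clone()
```

Because `total_loss` uses `.sum()`, its gradient equals the gradient of summing only the selected losses. If you use `.mean()` over the full batch instead, the gradient is divided by `batch_size`, not by the number of selected rows. The forward pass still runs on the whole batch, so layers that mix information across the batch (such as `BatchNorm` in training mode) are still influenced by the masked elements.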