Setting gradient to zero manually?


For some predictions of my model, I want the optimizer not to consider them when updating the parameters. My idea was to set the gradient of the loss corresponding to these predictions to zero:

def loss_fn(y_pred, y):
    loss = some_function(y_pred, y)
    if y_pred == specific_prediction:
        loss.grad = None
    return loss

Would that be conceptually correct?


Best, JZ

You can simply use a binary mask to zero out the losses of the unwanted predictions:

loss = some_loss_function(y_pred, y)  # use reduction='none' so you get one loss value per sample
loss = loss * binary_mask
loss = loss.sum()  # or loss.mean(), but note that mean also averages over the masked entries

where binary_mask[i] = 0 for unwanted predictions and binary_mask[i] = 1 otherwise.
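As a concrete sketch of the masking approach (the class used to build the mask and the shapes are made up for illustration): with reduction='none' the criterion returns a per-sample loss vector, and multiplying it by a 0/1 mask before the reduction makes the gradients of the masked samples exactly zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

y_pred = torch.randn(4, 3, requires_grad=True)  # logits for 4 samples, 3 classes
y = torch.tensor([0, 1, 2, 1])                  # targets

# reduction='none' keeps one loss value per sample, so it can be masked
criterion = nn.CrossEntropyLoss(reduction='none')
loss_per_sample = criterion(y_pred, y)          # shape: (4,)

# Hypothetical rule: treat samples with target class 1 as "unwanted"
binary_mask = (y != 1).float()                  # [1., 0., 1., 0.]

loss = (loss_per_sample * binary_mask).sum()
loss.backward()

# Rows of y_pred.grad belonging to masked samples are all zeros,
# so the optimizer will not update parameters based on them
print(y_pred.grad)
```

If you use the mean instead of the sum, consider dividing by binary_mask.sum() rather than calling loss.mean(), so the masked samples do not dilute the average.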