I would like to compute a gradient penalty w.r.t. the model parameters that are related to specific input variables in the first layer. To do that, I implemented a custom loss function like this:
import torch

def custom_loss(model, output, target, penalization_coef, input_indexes, mode: str):
    assert mode in ("train", "eval")
    main_loss = torch.mean((output - target) ** 2)  # some loss function for the predictive task
    # compute penalisation norm
    if mode == "eval":
        penalisation_norm = 0
    else:
        main_loss.backward(retain_graph=True)  # populate .grad on the parameters
        param = next(model.parameters())  # get first-layer weights
        # L1 norm of the gradient entries associated with the input variables listed in input_indexes
        penalisation_norm = penalization_coef * param.grad[:, input_indexes].abs().sum()
        for param in model.parameters():  # reset gradients
            param.grad = None
    aggregated_loss = main_loss + penalisation_norm
    ...
    return main_loss, aggregated_loss
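For context, the loss is used in the training loop roughly like this (simplified sketch; train_loader, optimizer, penalization_coef and input_indexes are the usual objects from my setup):

for x, y in train_loader:
    optimizer.zero_grad()
    output = model(x)
    main_loss, aggregated_loss = custom_loss(
        model, output, y, penalization_coef, input_indexes, mode="train"
    )
    aggregated_loss.backward()  # backprop through the combined loss
    optimizer.step()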
The problem is that it doesn’t seem to make any difference compared to training without the penalisation… What am I doing wrong here?