Check gradient of a particular data sample during training

Hi, during training I found that batches containing certain data samples consistently get very large gradients. However, since I have so many training samples, I cannot figure out which samples in a batch are getting these bad gradients.

Is there any way in PyTorch to print, for each data sample in a batch, the norm of the gradient with respect to an intermediate feature? For example, if I have:
data >> fc_layer1 >> fc1_feature >> fc_layer2 >> fc2_feature >> softmaxloss_layer >> loss, I want to print norm( d(loss)/d(fc1_feature) ) for every data sample.
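A minimal sketch of what is being asked, assuming two plain `nn.Linear` layers stand in for `fc_layer1`/`fc_layer2` and `F.cross_entropy` stands in for the softmax loss (sizes and data are made up). Because each sample's loss depends only on its own row of `fc1_feature` here, row `i` of the gradient is that sample's gradient, so taking the norm per row gives one number per sample. `retain_grad()` keeps `.grad` on the non-leaf feature tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the layers in the question
fc_layer1 = nn.Linear(8, 16)
fc_layer2 = nn.Linear(16, 4)

data = torch.randn(5, 8)              # batch of 5 made-up samples
targets = torch.randint(0, 4, (5,))

fc1_feature = fc_layer1(data)
fc1_feature.retain_grad()             # keep .grad on this intermediate (non-leaf) tensor
fc2_feature = fc_layer2(fc1_feature)
# reduction="sum" so row i of the gradient is exactly sample i's own gradient
loss = F.cross_entropy(fc2_feature, targets, reduction="sum")
loss.backward()

# One norm per sample: row i of fc1_feature.grad is d(loss)/d(fc1_feature[i])
per_sample_norms = fc1_feature.grad.norm(dim=1)
for i, n in enumerate(per_sample_norms.tolist()):
    print(f"sample {i}: ||d(loss)/d(fc1_feature)|| = {n:.4f}")
```

With the default `reduction="mean"` every row is just scaled by 1/batch_size, so the ranking of samples is unchanged. This per-row reading stops being valid if a layer mixes samples (e.g. BatchNorm) between `fc1_feature` and the loss.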

Thank you very much for the help!

Hi @Paralysis

You can register hooks to peek into the intermediate computations. There’s an example in the PyTorch docs, and I’ve pasted it with some slight modifications into this notebook gist:
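Since the gist itself isn't reproduced here, this is only a sketch of the hook idea under assumptions: a hypothetical `TwoLayerNet`, and a forward hook on `fc_layer1` that attaches a tensor hook (`Tensor.register_hook`) to the layer's output. The tensor hook fires during `backward()` with d(loss)/d(output), and it records one norm per sample.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TwoLayerNet(nn.Module):
    # Hypothetical model standing in for the one in the question
    def __init__(self):
        super().__init__()
        self.fc_layer1 = nn.Linear(8, 16)
        self.fc_layer2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc_layer2(self.fc_layer1(x))

model = TwoLayerNet()
grad_norms = {}  # layer name -> per-sample gradient norms, filled during backward

def attach_grad_norm_hook(name):
    def forward_hook(module, inputs, output):
        # Runs during backward with grad = d(loss)/d(output)
        def tensor_hook(grad):
            # flatten(1) so features with extra dims still give one norm per sample
            grad_norms[name] = grad.detach().flatten(1).norm(dim=1)
        output.register_hook(tensor_hook)  # returning None leaves grad unchanged
    return forward_hook

handle = model.fc_layer1.register_forward_hook(attach_grad_norm_hook("fc_layer1"))

data = torch.randn(5, 8)
targets = torch.randint(0, 4, (5,))
loss = F.cross_entropy(model(data), targets, reduction="sum")
loss.backward()

print(grad_norms["fc_layer1"])  # one norm per sample in the batch
worst = int(grad_norms["fc_layer1"].argmax())
print("sample with the largest gradient norm:", worst)
handle.remove()  # remove the hook when you no longer need it
```

Mapping `worst` back to a dataset index would need the dataset to return sample indices alongside the data, which is an assumption not covered above.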

You will have to write a function that recurses into modules with submodules in order to register hooks on all of them, but that should be pretty straightforward.
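One possible shape for that recursive function, as a sketch: `register_grad_norm_hooks` is a hypothetical helper name, it walks `named_children()` and hooks only leaf modules, and the nested `nn.Sequential` model is made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def register_grad_norm_hooks(module, store, prefix=""):
    """Hypothetical helper: recursively hook every leaf submodule's output."""
    handles = []
    children = list(module.named_children())
    if not children:  # leaf module: record per-sample grad norms of its output
        def forward_hook(mod, inputs, output):
            if torch.is_tensor(output) and output.requires_grad:
                def tensor_hook(grad):
                    store[prefix] = grad.detach().flatten(1).norm(dim=1)
                output.register_hook(tensor_hook)
        handles.append(module.register_forward_hook(forward_hook))
        return handles
    for child_name, child in children:  # recurse into submodules
        full_name = f"{prefix}.{child_name}" if prefix else child_name
        handles.extend(register_grad_norm_hooks(child, store, full_name))
    return handles

torch.manual_seed(0)
# Made-up nested model to show the recursion reaching inner layers
model = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),
    nn.Sequential(nn.Linear(16, 16), nn.ReLU()),
    nn.Linear(16, 4),
)
store = {}
handles = register_grad_norm_hooks(model, store)

data = torch.randn(5, 8)
targets = torch.randint(0, 4, (5,))
F.cross_entropy(model(data), targets, reduction="sum").backward()

for name, norms in sorted(store.items()):
    print(name, norms)
for h in handles:
    h.remove()
```

As an alternative to the hand-written recursion, `model.named_modules()` already walks the whole module tree, so a flat loop over it (skipping modules that have children) would reach the same leaves.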
