For my project I need to record the gradient of each individual sample in a batch (not the mean/sum produced by backpropagating the whole batch at once). I am working with a VGG16 architecture that contains nn.BatchNorm2d layers.
If I split the batch into single elements, record the gradients one at a time, and finally sum them together before optimizer.step(), I cannot use nn.BatchNorm2d (the batch statistics are degenerate for a single-element batch). Is there a way to keep the nn.BatchNorm2d layers and still forward the samples one by one?
(For example: first forward the whole batch, somehow freeze the BatchNorm state computed from that batch, and then repeat the forward pass sample by sample.)
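A minimal sketch of the idea in the parenthesis above, using a tiny conv + BatchNorm stack as a stand-in for VGG16 (the model, shapes, and dummy loss are illustrative assumptions): forward the full batch once in train mode so BatchNorm updates its running statistics, then switch the BatchNorm modules to eval mode so single-sample forwards normalize with those frozen statistics instead of per-batch ones.

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for VGG16: one conv + BatchNorm.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1),
                      nn.BatchNorm2d(4))
batch = torch.randn(8, 3, 16, 16)

# Optional: momentum=1.0 makes the running statistics equal the current
# batch's statistics (up to BatchNorm's unbiased-variance correction),
# so the frozen state matches the full-batch forward closely.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = 1.0

# 1) Forward the whole batch once in train mode: BatchNorm sees the
#    full-batch statistics and writes them into its running buffers.
model.train()
_ = model(batch)

# 2) Freeze BatchNorm: in eval mode it normalizes with the running
#    statistics just stored, so a batch of size 1 is now valid.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

# 3) Forward samples one by one and record per-sample gradients
#    (a dummy sum() stands in for the real loss here).
per_sample_grads = []
for x in batch:
    model.zero_grad()
    out = model(x.unsqueeze(0))
    out.sum().backward()
    per_sample_grads.append(
        [p.grad.detach().clone() for p in model.parameters()]
    )
```

Note the caveat: eval-mode normalization uses the stored running statistics, which differ slightly from the training-mode batch statistics (unbiased vs. biased variance, and any earlier history unless momentum is set to 1.0), so the per-sample gradients will not sum exactly to the full-batch gradient.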