Hi,
I have been trying to implement Batch Renormalization. Implementing it in Python as a custom Module results in much higher memory usage and run time, so I created a BatchReNorm function in CUDA, following the same pattern as BatchNorm in THCUNN.
I am able to build it successfully, and the updates and the clipping work perfectly; however, I am not sure about the stop_gradient mentioned in the batch renorm paper.
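For reference, if I am reading the paper correctly, r and d are defined from the minibatch statistics $\mu_B, \sigma_B$ and the moving statistics $\mu, \sigma$ as

$$r = \operatorname{stop\_gradient}\!\left(\operatorname{clip}_{[1/r_{\max},\, r_{\max}]}\!\left(\tfrac{\sigma_B}{\sigma}\right)\right), \qquad d = \operatorname{stop\_gradient}\!\left(\operatorname{clip}_{[-d_{\max},\, d_{\max}]}\!\left(\tfrac{\mu_B - \mu}{\sigma}\right)\right),$$

with the forward pass $\hat{x}_i = \frac{x_i - \mu_B}{\sigma_B}\, r + d$ and $y_i = \gamma \hat{x}_i + \beta$.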
I create the parameters r and d in the BatchReNormalizationUpdateOutput_kernel as below:
Acctype r = 0;
Acctype d = 0;
I then clamp the values using THCNumerics<Acctype>::lt and THCNumerics<Acctype>::gt.
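To make the clamping step concrete, here is a stripped-down, standalone sketch of what I am doing per channel. This is plain CUDA rather than the actual THCUNN kernel, and all the names (batchMean, batchInvStd, runningMean, runningStd, rMax, dMax) are placeholders; in the real kernel the comparisons go through THCNumerics<Acctype>::lt / ::gt rather than fminf / fmaxf.

```cpp
#include <cuda_runtime.h>

// Sketch only: per-channel computation and clipping of r and d.
// batchInvStd is 1 / sigma_B; runningStd is the moving sigma.
__global__ void compute_r_d(const float* batchMean, const float* batchInvStd,
                            const float* runningMean, const float* runningStd,
                            float* r, float* d,
                            float rMax, float dMax, int channels) {
  int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c >= channels) return;

  float batchStd = 1.0f / batchInvStd[c];

  // r = sigma_B / sigma, clipped to [1/rMax, rMax]
  float rc = batchStd / runningStd[c];
  rc = fminf(fmaxf(rc, 1.0f / rMax), rMax);

  // d = (mu_B - mu) / sigma, clipped to [-dMax, dMax]
  float dc = (batchMean[c] - runningMean[c]) / runningStd[c];
  dc = fminf(fmaxf(dc, -dMax), dMax);

  // The forward pass then uses r and d as plain per-channel constants:
  //   xhat = (x - mu_B) * batchInvStd * r + d;   y = gamma * xhat + beta
  r[c] = rc;
  d[c] = dc;
}
```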
I do not pass r and d to the BatchReNormalizationBackward_kernel. Does that ensure that the gradients don't flow through them?
Any insight would be really helpful.
Regards
Nabarun