I get the intuition behind LRP and I would like to implement it in PyTorch. However, I'm not familiar enough with PyTorch internals to know how best to get started.
The website links to an LRP wrapper for TensorFlow. Looking at the code, it seems one has to wrap each layer and add some LRP-specific code to it. I assume the same would be required for PyTorch.
Has anyone already implemented LRP in PyTorch? Or does anyone have opinions on how to tackle this?
So what I did was implement autograd.Functions that use the regular forward pass but implement the LRP rules in backward. Then I mirrored all the nn.Modules my network uses. This worked for LSTMs (ULMFiT), ResNets, and a few others for me.
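To make that concrete, here is a minimal sketch of the idea for a single linear layer, using the LRP-epsilon rule. The class names (`LRPLinearFunction`, `LRPLinear`) and the `eps` default are my own choices for illustration, not from any particular library:

```python
import torch
import torch.nn as nn


class LRPLinearFunction(torch.autograd.Function):
    """Regular forward pass; backward propagates LRP-epsilon relevance
    instead of gradients."""

    @staticmethod
    def forward(ctx, input, weight, bias, eps):
        output = input @ weight.t()
        if bias is not None:
            output = output + bias
        ctx.save_for_backward(input, weight, output)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, relevance):
        input, weight, output = ctx.saved_tensors
        # LRP-epsilon: stabilize the denominator with eps * sign(z)
        z = output + ctx.eps * output.sign()
        s = relevance / z          # per-output "sensitivities"
        c = s @ weight             # redistribute back to the inputs
        # R_j = a_j * c_j; no relevance flows to weight/bias/eps
        return input * c, None, None, None


class LRPLinear(nn.Linear):
    """Drop-in replacement for nn.Linear: forward is unchanged,
    but backward yields relevances."""

    def __init__(self, in_features, out_features, bias=True, eps=1e-6):
        super().__init__(in_features, out_features, bias)
        self.eps = eps

    def forward(self, input):
        return LRPLinearFunction.apply(input, self.weight, self.bias, self.eps)
```

After building the network from such mirrored modules, you run a normal forward pass, seed `backward()` with a relevance tensor (e.g. the output masked to the target class), and read the relevance off `input.grad`. For layers without bias and small `eps`, the input relevances sum approximately to the seeded output relevance (conservation).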
Note that “unconstrained” LRP for ResNets has continuity issues at the residual connections; here is a discussion involving some coauthors of the relevant papers.