I’m not sure I understand how to reduce DWT to convolutions as you propose.
Edit: Are you proposing to basically construct nlogn conv filters? since the higher frequency wavelets are just translated and mostly 0 everywhere I suppose it’s possible. Do you have an example? Not sure how I would stitch it all up.
The scattering transform is indeed interesting but not what I’m immediately looking for, will check it out though thanks. Seems to have high potential, too bad you can’t backprop through the fast implementation though.
I recently implemented a wavelet filter bank in PyTorch. Although my focus here was to write a fast temporal convolution library for wavelets, this might be of interest to you: https://github.com/tomrunia/PyTorchWavelets
The repo has tests checking that gradients pass nicely through the dual tree DWT, but I am yet to write tests to check the gradients for the DWT. I would wager they work nicely, but do need to confirm. The perfect reconstruction works, as well as getting the same wavelet coefficients as pywt.
Hi @tom@fbcotter.
One thing that appears to be different between both of your implementations is the requires_grad parameter. As far as I can understand @tom’s solution does not explicitly state that the convolution layer parameters have requires_grad = False and thus they would get updated at backprop (as would any kernel of a conv2d or transposeconv2d layer). Similarly in @fbcotter’s repo, the parameters are explicitly marked with requires_grad = False.
For my implementation I need a 1D wavelet based loss, so I would require the parameters to have requires_grad = False, it seems like it might be slightly easier for me to adapt tom’s version, so I wanted to make sure I understand this before going ahead and using the code.
So I finally removed all the long deprecated Variable bits.
My notebook makes no assumption about whether you pass in weights which require grads or not - if you do, it will happily backpropagate into them.
I did use a similar approach for audio (1d).
Thanks for the update @tom So essentially, to make sure that the DWT is fixed =, I would simply set dec_hi.requires_grad = False (and same for the others off course). And then everything should work.
Perhaps I’ll write this as a pytorch module so it can be easily integrated into a network.
Not requiring grad is the default, so no need to set things.
My personal philosophy here is to keep stateless things as functions, but having it in a module is fine, too.