Backprop through Discrete Wavelet Transform (DWT) on GPU


#1

Is there an efficient way to perform this operation in 2D?

The naive approach is either to loop over the rows and columns or to construct the transform matrix and multiply.

Oddly, I was unable to find an example of how to do this for either Theano or TensorFlow.


(Thomas V) #2

You can use conv layers to compute single levels (using coefficients, e.g. from pywavelets, and setting requires_grad to False).
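A minimal sketch of that idea, assuming Haar filters for concreteness (the values are hardcoded but equal pywt.Wavelet("haar").dec_lo / dec_hi, so the snippet stays self-contained; note that pywt's default symmetric boundary padding and filter-flip conventions are not reproduced here):

```python
import torch
import torch.nn.functional as F

# Haar analysis filters (same values as pywt.Wavelet("haar").dec_lo/dec_hi).
s = 0.5 ** 0.5
lo = torch.tensor([s, s])
hi = torch.tensor([s, -s])

# Four 2D subband filters (LL, LH, HL, HH) as outer products, shaped
# (out_channels, in_channels, kH, kW) for conv2d.
filters = torch.stack(
    [torch.outer(a, b) for a in (lo, hi) for b in (lo, hi)]
).unsqueeze(1)
filters.requires_grad_(False)  # fixed wavelet filters, not learned

def dwt_level(x):
    # x: (N, 1, H, W) -> (N, 4, H/2, W/2); stride 2 does the downsampling.
    return F.conv2d(x, filters, stride=2)

x = torch.randn(1, 1, 8, 8, requires_grad=True)
coeffs = dwt_level(x)    # (1, 4, 4, 4)
coeffs.sum().backward()  # gradients flow back through conv2d automatically
```

Because the whole thing is a single strided conv2d, autograd handles the backward pass with no extra work.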
Also check out

that seems to be the most wavelet-oriented neural network work I have seen recently.

I think I saw Edouard around here, but I forgot his username.

Best regards

Thomas


#3

I’m not sure I understand how to reduce the DWT to convolutions as you propose.
Edit: Are you proposing to basically construct n log n conv filters? Since the higher-frequency wavelets are just translations of each other and mostly zero everywhere, I suppose it’s possible. Do you have an example? I’m not sure how I would stitch it all together.
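One common way to stitch the levels together, sketched here with Haar filters for concreteness, avoids constructing n log n filters entirely: apply one fixed four-filter bank and recurse on the low-pass (LL) band. The `dwt` helper below is illustrative, not from any library:

```python
import torch
import torch.nn.functional as F

s = 0.5 ** 0.5  # Haar analysis filters (same values as pywt's "haar")
lo, hi = torch.tensor([s, s]), torch.tensor([s, -s])
filters = torch.stack(
    [torch.outer(a, b) for a in (lo, hi) for b in (lo, hi)]
).unsqueeze(1)  # (4, 1, 2, 2): LL, LH, HL, HH

def dwt(x, levels):
    details = []
    for _ in range(levels):
        y = F.conv2d(x, filters, stride=2)
        x, d = y[:, :1], y[:, 1:]  # LL band feeds the next level
        details.append(d)
    return x, details

x = torch.randn(1, 1, 16, 16)
ll, details = dwt(x, 2)
# ll: (1, 1, 4, 4); details[0]: (1, 3, 8, 8); details[1]: (1, 3, 4, 4)
```

Each level halves the resolution, so the total work is a geometric series rather than n log n separate filters.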

The scattering transform is indeed interesting but not what I’m immediately looking for; I will check it out though, thanks. It seems to have high potential; too bad you can’t backprop through the fast implementation.


(Thomas V) #4

As talk is cheap, I wrote up a quick notebook on 2D wavelet transforms with PyTorch; I hope it is useful for you.

Best regards

Thomas


#5

It certainly is; you have clearly gone to some effort.
Thanks.


(Tom) #6

I recently implemented a wavelet filter bank in PyTorch. Although my focus here was to write a fast temporal convolution library for wavelets, this might be of interest to you: https://github.com/tomrunia/PyTorchWavelets


(Fergal Cotter) #7

@tom Great work on the notebook. I already had a repo to do a dual tree DWT in pytorch (https://en.wikipedia.org/wiki/Complex_wavelet_transform), and inspired by your code, I have now added support for the DWT and Inverse DWT in 2 dimensions. You can check it out at
https://github.com/fbcotter/pytorch_wavelets (@verified.human sorry for the similar name!)
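For an orthonormal wavelet like Haar, the inverse single-level DWT is simply the transposed convolution with the same filter bank, which gives perfect reconstruction. A minimal sketch of that, assuming Haar (longer or biorthogonal filters need separate synthesis filters and boundary handling):

```python
import torch
import torch.nn.functional as F

s = 0.5 ** 0.5  # orthonormal Haar filters
lo, hi = torch.tensor([s, s]), torch.tensor([s, -s])
filters = torch.stack(
    [torch.outer(a, b) for a in (lo, hi) for b in (lo, hi)]
).unsqueeze(1)

x = torch.randn(1, 1, 8, 8)
coeffs = F.conv2d(x, filters, stride=2)                # analysis (DWT)
recon = F.conv_transpose2d(coeffs, filters, stride=2)  # synthesis (IDWT)
assert torch.allclose(recon, x, atol=1e-5)             # perfect reconstruction
```

This works because the four Haar subband filters form an orthonormal basis over each non-overlapping 2x2 block, so the transpose of the analysis operator is its inverse.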

The repo has tests checking that gradients pass nicely through the dual tree DWT, but I have yet to write tests that check the gradients for the DWT. I would wager they work nicely, but I need to confirm. Perfect reconstruction works, and the wavelet coefficients match pywt.
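Such a gradient test can be sketched with torch.autograd.gradcheck, which compares analytic gradients against finite differences (double precision is required for the comparison to pass). This is a generic sketch on a conv-based Haar DWT, not the tests from the repo above:

```python
import torch
import torch.nn.functional as F

s = 0.5 ** 0.5  # Haar analysis filters in float64 for gradcheck
lo = torch.tensor([s, s], dtype=torch.float64)
hi = torch.tensor([s, -s], dtype=torch.float64)
filters = torch.stack(
    [torch.outer(a, b) for a in (lo, hi) for b in (lo, hi)]
).unsqueeze(1)

def dwt(x):
    # Single-level 2D DWT as a strided convolution.
    return F.conv2d(x, filters, stride=2)

x = torch.randn(1, 1, 8, 8, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(dwt, (x,))  # raises or returns True
```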


(Thomas V) #8

Thanks for sharing, @fbcotter!

It’s not terribly efficient, but wouldn’t autograd tracing handle backpropagation until you have an explicit backward?