[solved] Discrete cosine transform implementation in pytorch?

UPDATE: I figured it out. The PyTorch tutorial on NumPy/SciPy extensions explains how to incorporate SciPy functions into the computational graph, so you can use scipy.fftpack.dct: http://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html
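Here is a minimal sketch of that tutorial's pattern applied to the inverse DCT. It assumes orthonormal scaling (norm="ortho") along the last dimension, and the class name IDCTFunction is hypothetical. With orthonormal scaling the DCT matrix is orthogonal, so the inverse transform is its transpose and the gradient of the inverse is just the forward DCT of the incoming gradient.

```python
import torch
from scipy.fftpack import dct, idct


class IDCTFunction(torch.autograd.Function):
    """Orthonormal inverse DCT (DCT-III) along the last dim, computed with SciPy."""

    @staticmethod
    def forward(ctx, x):
        # SciPy works on NumPy arrays, so leave the graph and move to the CPU.
        result = idct(x.detach().cpu().numpy(), type=2, norm="ortho", axis=-1)
        return torch.as_tensor(result, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        # The orthonormal DCT matrix C is orthogonal: forward computes C^T @ x,
        # so the gradient is C @ grad, i.e. a forward orthonormal DCT-II.
        grad = dct(grad_output.detach().cpu().numpy(), type=2, norm="ortho", axis=-1)
        return torch.as_tensor(grad, dtype=grad_output.dtype, device=grad_output.device)
```

Call it as `y = IDCTFunction.apply(x)`. Because the data round-trips through NumPy on the CPU, this is convenient but not ideal for GPU tensors.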

Hi everyone,
I want to backprop the error gradient through the inverse discrete cosine transform.

I tried implementing a naive version of the DCT, but it was really slow.

Does anyone have any suggestions for writing a fast, backpropable implementation of the inverse discrete cosine transform?
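The naive DCT definition does not need Python loops. It can be written as a single matrix multiply, which autograd differentiates automatically and which runs on the GPU. Below is a minimal sketch assuming the orthonormal DCT-II convention (the same scaling as SciPy's norm="ortho") along the last dimension. The helper names dct_matrix, dct_matmul and idct_matmul are hypothetical. Because the orthonormal matrix is orthogonal, the inverse DCT is just multiplication by its transpose.

```python
import math
import torch


def dct_matrix(n, dtype=torch.float32, device=None):
    """Orthonormal DCT-II matrix C (n x n), so that X = C @ x."""
    k = torch.arange(n, dtype=dtype, device=device).unsqueeze(1)  # frequency index (rows)
    i = torch.arange(n, dtype=dtype, device=device).unsqueeze(0)  # sample index (columns)
    c = torch.cos(math.pi * (2 * i + 1) * k / (2 * n)) * math.sqrt(2.0 / n)
    c[0] /= math.sqrt(2.0)  # extra DC-row scaling makes C orthogonal
    return c


def dct_matmul(x, c):
    # X[..., k] = sum_n C[k, n] * x[..., n], applied along the last dimension
    return x @ c.T


def idct_matmul(X, c):
    # C is orthogonal, so its inverse is C^T: x[..., n] = sum_k C[k, n] * X[..., k]
    return X @ c
```

This costs O(N²) per transform, but as one batched matmul it is usually fast for moderate N. Build `c` once and reuse it.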

I can paste my code here. I only left it out because it is very messy and probably wouldn't help clarify the situation.


For anyone coming here from a Google search: I have implemented the DCT for PyTorch in terms of the built-in FFT, so it works on both CPU and GPU and supports backpropagation:


I just found this on GitHub, many thanks 👍
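The FFT-based approach described above can be sketched with PyTorch's `torch.fft` module (PyTorch 1.8 or later). This is an independent sketch, not the linked implementation. It uses Makhoul's reordering, so a length-N DCT-II needs only one length-N complex FFT. It assumes orthonormal scaling along the last dimension, and the names dct_fft, idct_fft and the underscore helpers are hypothetical. Gradients come from autograd, so no custom backward is needed.

```python
import math
import torch


def _makhoul_perm(n, device):
    # Even-indexed samples first, then odd-indexed samples in reverse order.
    return torch.cat([torch.arange(0, n, 2, device=device),
                      torch.arange(1, n, 2, device=device).flip(0)])


def _ortho_scale(n, dtype, device):
    # sqrt(2/N) for every bin, sqrt(1/N) for the DC bin k = 0.
    scale = torch.full((n,), math.sqrt(2.0 / n), dtype=dtype, device=device)
    scale[0] = math.sqrt(1.0 / n)
    return scale


def dct_fft(x):
    """Orthonormal DCT-II along the last dimension using one length-N FFT."""
    n = x.shape[-1]
    v = x[..., _makhoul_perm(n, x.device)]
    V = torch.fft.fft(v, dim=-1)
    theta = math.pi * torch.arange(n, dtype=x.dtype, device=x.device) / (2 * n)
    # Re(exp(-i*theta) * V) = cos(theta) * Re(V) + sin(theta) * Im(V)
    s = V.real * torch.cos(theta) + V.imag * torch.sin(theta)
    return s * _ortho_scale(n, x.dtype, x.device)


def idct_fft(X):
    """Orthonormal DCT-III (the inverse of dct_fft) along the last dimension."""
    n = X.shape[-1]
    s = X / _ortho_scale(n, X.dtype, X.device)  # undo the orthonormal scaling
    theta = math.pi * torch.arange(n, dtype=X.dtype, device=X.device) / (2 * n)
    # For a real signal, exp(-i*theta_k) * V_k = S_k - i * S_{N-k}, with S_N = 0.
    s_rev = torch.cat([torch.zeros_like(s[..., :1]), s[..., 1:].flip(-1)], dim=-1)
    # V_k = exp(i*theta_k) * (S_k - i * S_{N-k}), split into real and imaginary parts.
    v_re = s * torch.cos(theta) + s_rev * torch.sin(theta)
    v_im = s * torch.sin(theta) - s_rev * torch.cos(theta)
    v = torch.fft.ifft(torch.complex(v_re, v_im), dim=-1).real
    # Undo the even/odd reordering with the inverse permutation.
    inv = torch.argsort(_makhoul_perm(n, X.device))
    return v[..., inv]
```

`idct_fft` inverts `dct_fft` in two steps. It first rebuilds the FFT spectrum from the DCT coefficients, using the Hermitian symmetry of a real signal's FFT. It then undoes the reordering. Both functions work on CPU or GPU tensors and are O(N log N).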