I am studying the discrete cosine transform (DCT) and the JPEG compression technique. I do not know how to apply the DCT to a 2-D image tensor and keep only the most important coefficients.
I know that I can perform the DCT efficiently with https://github.com/zh217/torch-dct. Here is an example:
```python
import torch
import torch_dct as dct

x = torch.randn(8, 8)
x_dct = dct.dct_2d(x)       # 2-D DCT of x
x_rec = dct.idct_2d(x_dct)  # inverse 2-D DCT
print(torch.abs(x - x_rec).mean())  # x == x_rec within numerical tolerance
```
However, I want to keep only the most important DCT coefficients, which sit in the upper-left corner. In other words, how can I generate a zig-zag binary mask for a given image size and number of coefficients?
The following function is what I want:
```python
>>> f(h=3, w=3, q=1)
[[1, 0, 0], [0, 0, 0], [0, 0, 0]]
>>> f(h=3, w=3, q=2)
[[1, 1, 0], [0, 0, 0], [0, 0, 0]]
>>> f(h=3, w=3, q=5)
[[1, 1, 0], [1, 1, 0], [1, 0, 0]]
>>> f(h=4, w=4, q=10)
[[1, 1, 1, 1], [1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]]
```
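A minimal vectorized sketch, assuming the intended order is the standard JPEG zig-zag scan (the expected outputs above are consistent with it): assign every coefficient (i, j) its position in the scan, then compare that rank tensor with q. The scan visits anti-diagonals s = i + j in increasing order and alternates direction on each one, so sorting by the key s*h + (i if s is odd else h - 1 - i) reproduces it without a Python loop. The function names `zigzag_rank` and `zigzag_mask` are hypothetical, not part of torch-dct.

```python
import torch


def zigzag_rank(h: int, w: int, device=None) -> torch.Tensor:
    """Hypothetical helper: (h, w) LongTensor whose entry [i, j] is the
    position of coefficient (i, j) in a JPEG-style zig-zag scan."""
    i = torch.arange(h, device=device).view(h, 1)
    j = torch.arange(w, device=device).view(1, w)
    s = i + j  # anti-diagonal index, broadcast to (h, w)
    # Odd diagonals are scanned top-to-bottom (i increasing),
    # even diagonals bottom-to-top (i decreasing).
    secondary = torch.where(s % 2 == 1, i, h - 1 - i)
    key = s * h + secondary  # unique key that sorts cells in scan order
    order = torch.argsort(key.flatten())
    # Invert the permutation: rank[order[k]] = k.
    rank = torch.empty_like(order)
    rank[order] = torch.arange(h * w, device=device)
    return rank.view(h, w)


def zigzag_mask(h: int, w: int, q: int, device=None) -> torch.Tensor:
    """Hypothetical helper: 1 for the first q coefficients of the scan, else 0."""
    return (zigzag_rank(h, w, device) < q).to(torch.int64)


print(zigzag_mask(3, 3, 5).tolist())  # [[1, 1, 0], [1, 1, 0], [1, 0, 0]]
```

Only `argsort`, indexing and elementwise ops are involved, and the same key also covers h != w (the diagonals are simply clipped at the border). The rank tensor depends only on (h, w), so it can be computed once and reused; a mask for a new q is then a single comparison.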
(I will use it for a batch of images with shape [b, c, h, w].)
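A single (h, w) mask can be applied to a whole [b, c, h, w] batch by plain broadcasting, with no need to tile it. As far as I can tell from the torch-dct source, `dct.dct_2d` transforms the last two dimensions, so its output on such a batch has the same layout; in this sketch a random tensor stands in for those coefficients to avoid depending on torch-dct, and the mask is the q = 10 example copied from the expected output above.

```python
import torch

# Mask for h = w = 4, q = 10, copied from the expected output in the question.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0],
                     [1, 1, 0, 0],
                     [1, 0, 0, 0]], dtype=torch.float32)

b, c, h, w = 2, 3, 4, 4
x_dct = torch.randn(b, c, h, w)  # stand-in for the DCT coefficients of a batch
x_kept = x_dct * mask            # (h, w) broadcasts over the b and c dimensions
```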
How can I implement this function efficiently in PyTorch? Or is there another way to obtain a similarly efficient function?
P.S. I know I can implement the function above with for-loops, but the result would probably be too slow for my purposes. Since I need to generate several masks in every forward pass of the network, I am looking for an implementation fast enough to run in real time.
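If many masks are needed per forward pass, one option is to cache the zig-zag rank for each (h, w, device) and build all masks for a batch in one broadcast comparison, one q per sample. This is a sketch under the assumption that the JPEG zig-zag order is wanted; `cached_zigzag_rank` and `batch_zigzag_masks` are hypothetical names, and the random tensor is only a stand-in for real DCT coefficients.

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def cached_zigzag_rank(h: int, w: int, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: (h, w) tensor of JPEG-style zig-zag scan positions,
    computed once per (h, w, device) and then served from the cache."""
    i = torch.arange(h, device=device).view(h, 1)
    j = torch.arange(w, device=device).view(1, w)
    s = i + j  # anti-diagonal index
    # Scan key: anti-diagonal first; odd diagonals top-to-bottom, even ones bottom-to-top.
    key = s * h + torch.where(s % 2 == 1, i, h - 1 - i)
    # argsort of argsort turns the (unique) keys into ranks 0 .. h*w-1.
    return torch.argsort(torch.argsort(key.flatten())).view(h, w)


def batch_zigzag_masks(q: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Hypothetical helper: one mask per entry of q, shape (len(q), h, w)."""
    rank = cached_zigzag_rank(h, w, str(q.device))
    return (rank < q.view(-1, 1, 1)).to(torch.float32)


b, c, h, w = 4, 3, 8, 8
x_dct = torch.randn(b, c, h, w)      # stand-in for dct.dct_2d(x) on a [b, c, h, w] batch
q = torch.tensor([1, 6, 10, 64])     # number of coefficients to keep for each sample
masks = batch_zigzag_masks(q, h, w)  # (b, h, w)
x_kept = x_dct * masks[:, None]      # broadcast each sample's mask over its channels
```

After the first call, each forward pass costs one comparison and one multiplication. The cached rank tensor is shared between calls, so it should not be modified in place.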