# How to efficiently preserve the most important DCT coefficients in PyTorch?

I am studying the discrete cosine transform (DCT) and the JPEG compression technique. How can I perform a DCT on a 2-D image tensor and preserve only the most important coefficients?

I know that I can perform the DCT efficiently with https://github.com/zh217/torch-dct. Here is an example:

``````python
import torch
import torch_dct as dct

x = torch.randn(8, 8)
x_dct = dct.dct_2d(x)
x_rec = dct.idct_2d(x_dct)
print(torch.abs(x - x_rec).mean())  # x == x_rec within numerical tolerance
``````

However, I want to preserve only the most important DCT coefficients in the upper-left corner. In other words, how can I generate a zigzag binary mask for a given image size and number of coefficients?

I want a function `f(h, w, q)` that marks the first `q` positions in zigzag order:

``````
>>> f(h=3, w=3, q=1)
[[1, 0, 0],
[0, 0, 0],
[0, 0, 0]]
>>> f(h=3, w=3, q=2)
[[1, 1, 0],
[0, 0, 0],
[0, 0, 0]]
>>> f(h=3, w=3, q=5)
[[1, 1, 0],
[1, 1, 0],
[1, 0, 0]]
>>> f(h=4, w=4, q=10)
[[1, 1, 1, 1],
[1, 1, 1, 0],
[1, 1, 0, 0],
[1, 0, 0, 0]]
``````

(I will apply it to a batch of images with shape `[b, c, h, w]`.)

How can I implement this function efficiently in PyTorch? Or is there another way to obtain a similarly efficient function?
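One possible vectorized approach, offered as a sketch rather than a definitive answer: rank every position by its anti-diagonal index `i + j`, break ties by column on even diagonals and by row on odd ones, and keep the `q` smallest ranks. The name `zigzag_mask` and its `device` parameter are hypothetical and not part of torch-dct. The sketch assumes the scan starts by moving right from the top-left corner, which is what the examples above imply.

```python
import torch


def zigzag_mask(h: int, w: int, q: int, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: [h, w] mask with ones at the first q zigzag positions."""
    i = torch.arange(h, device=device).view(h, 1).expand(h, w)  # row index
    j = torch.arange(w, device=device).view(1, w).expand(h, w)  # column index
    d = i + j  # anti-diagonal index of every position
    # Even diagonals are scanned with the column increasing,
    # odd diagonals with the row increasing.
    within = torch.where(d % 2 == 0, j, i)
    # d dominates the sort key because `within` is always smaller than max(h, w).
    key = d * max(h, w) + within
    order = torch.argsort(key.reshape(-1))
    mask = torch.zeros(h * w, device=device)
    mask[order[:q]] = 1
    return mask.view(h, w)


x_dct = torch.ones(2, 3, 4, 4)        # stand-in for a [b, c, h, w] batch of DCT coefficients
kept = x_dct * zigzag_mask(4, 4, 10)  # the [h, w] mask broadcasts over b and c
```

Because the mask depends only on `(h, w, q)`, it can be built once per size and reused across forward passes, so the sort is unlikely to be a bottleneck.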

P.S. I know I can implement the function with a few for-loops, but the result may be too slow for my purposes. Since I want to generate multiple masks in every forward pass of the network, I need an implementation fast enough for real-time use.
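For reference, the for-loop version mentioned in the P.S. might look roughly like the following pure-Python sketch. The name `zigzag_mask_loops` is hypothetical, and the code assumes the scan starts by moving right from the top-left corner. It is slow, but it pins down the expected behaviour and can serve as a ground truth when testing a faster implementation.

```python
def zigzag_mask_loops(h, w, q):
    """Hypothetical baseline: nested-list mask with ones at the first q zigzag positions."""
    mask = [[0] * w for _ in range(h)]
    count = 0
    for d in range(h + w - 1):  # d = i + j indexes the anti-diagonals
        # Rows that actually lie on diagonal d inside the h-by-w grid.
        rows = range(max(0, d - w + 1), min(d, h - 1) + 1)
        if d % 2 == 0:
            rows = reversed(rows)  # even diagonals run from bottom-left to top-right
        for i in rows:
            if count == q:
                return mask
            mask[i][d - i] = 1
            count += 1
    return mask


print(zigzag_mask_loops(3, 3, 5))
```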

I seem to have solved this problem myself. The following code appears to meet my requirements:

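``````python
import random
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def get_zigzag_ordered_indices(h=8, w=8):
    # Row indices x and column indices y of an h-by-w grid, in zigzag order.
    x, y = [], []
    # (x1, y1) is the lower-left end of the current anti-diagonal,
    # (x2, y2) is its upper-right end.
    x1, x2, y1, y2 = 0, 0, 0, 0
    flag = True  # traversal direction; flips after every anti-diagonal
    while x2 < h or y1 < w:
        if flag:
            x = [*x, *range(x1, x2 - 1, -1)]
            y = [*y, *range(y1, y2 + 1)]
        else:
            x = [*x, *range(x2, x1 + 1)]
            y = [*y, *range(y2, y1 - 1, -1)]
        flag = not flag
        # Slide each end point along the grid border to the next anti-diagonal.
        x1, y1 = (x1 + 1, 0) if (x1 < h - 1) else (h - 1, y1 + 1)
        x2, y2 = (0, y2 + 1) if (y2 < w - 1) else (x2 + 1, w - 1)
    return x, y


@lru_cache(maxsize=None)
def get_zigzag_truncated_indices(h=8, w=8, q=6):
    # First q zigzag positions, with a random choice of scan orientation.
    # NOTE: because of lru_cache, the random choice is made once per distinct
    # argument tuple and then reused on later calls with the same arguments.
    if random.randint(0, 1):
        x, y = get_zigzag_ordered_indices(h, w)
    else:
        # Transposed scan: the first step goes down instead of right.
        y, x = get_zigzag_ordered_indices(w, h)
    return x[:q], y[:q]


@lru_cache(maxsize=None)
def get_mask_DCT(h=8, w=8, q=6, device='cuda'):
    # Binary mask of shape [1, 1, h, w]; it broadcasts over a [b, c, h, w] batch.
    mask_DCT = torch.zeros(1, 1, h, w, device=device)
    DCT_x, DCT_y = get_zigzag_truncated_indices(h, w, q)
    mask_DCT[:, :, DCT_x, DCT_y] = 1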