How can I efficiently keep only the most important DCT coefficients in PyTorch?

I am studying the discrete cosine transform (DCT) and JPEG compression. How can I apply the DCT to a 2-D image tensor and keep only the most important coefficients?

I know that I can compute the DCT efficiently with torch-dct (https://github.com/zh217/torch-dct). Here is an example:

import torch
import torch_dct as dct

x = torch.randn(8, 8)
x_dct = dct.dct_2d(x)
x_rec = dct.idct_2d(x_dct)
print(torch.abs(x - x_rec).mean())  # x == x_rec within numerical tolerance

However, I want to keep only the most important DCT coefficients, which sit in the upper-left corner. In other words, how can I generate a zig-zag binary mask for a given image size and number of coefficients?

The following function is what I want:

>>> f(h=3, w=3, q=1)
[[1, 0, 0],
[0, 0, 0],
[0, 0, 0]]
>>> f(h=3, w=3, q=2)
[[1, 1, 0],
[0, 0, 0],
[0, 0, 0]]
>>> f(h=3, w=3, q=5)
[[1, 1, 0],
[1, 1, 0],
[1, 0, 0]]
>>> f(h=4, w=4, q=10)
[[1, 1, 1, 1],
[1, 1, 1, 0],
[1, 1, 0, 0],
[1, 0, 0, 0]]

(I will use it for a batch of images with shape [b, c, h, w].)

How can I implement this function efficiently in PyTorch? Or is there another efficient way to get a similar result?
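
For the batched case, a single [h, w] mask can be applied to every image and channel through broadcasting. The sketch below is only an illustration: the random tensor stands in for real DCT coefficients, and the mask is the 4x4, q=10 example from above written out by hand.

```python
import torch

# Stand-in for a batch of DCT coefficients with shape [b, c, h, w].
coeffs = torch.randn(2, 3, 4, 4)

# The 4x4, q=10 mask from the examples above, written out by hand.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0],
                     [1, 1, 0, 0],
                     [1, 0, 0, 0]], dtype=coeffs.dtype)

# A [h, w] mask broadcasts against [b, c, h, w] automatically.
masked = coeffs * mask

print(masked.shape)
print(int(mask.sum()))
```

The same multiplication works with a mask of shape [1, 1, h, w] or [b, 1, h, w].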


P.S. I know I can implement this function with a few for-loops, but the result may be too slow for my purposes. Since I need to generate several masks in every forward pass of the network, a fast implementation would be very welcome.

It seems that I have solved this problem myself. The following code appears to meet my requirements:

import random
import torch
from functools import lru_cache

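@lru_cache(maxsize=None)
def get_zigzag_ordered_indices(h=8, w=8):
    # Return the row indices x and column indices y of an h x w grid,
    # listed in zig-zag order, one anti-diagonal per loop iteration.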
    x, y = [], []
    x1, x2, y1, y2 = 0, 0, 0, 0
    flag = True
    while x2 < h or y1 < w:
        if flag:
            x = [*x, *range(x1, x2 - 1, -1)]
            y = [*y, *range(y1, y2 + 1)]
        else:
            x = [*x, *range(x2, x1 + 1)]
            y = [*y, *range(y2, y1 - 1, -1)]
        flag = not flag
        x1, y1 = (x1 + 1, 0) if (x1 < h - 1) else (h - 1, y1 + 1)
        x2, y2 = (0, y2 + 1) if (y2 < w - 1) else (x2 + 1, w - 1)
    return x, y

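@lru_cache(maxsize=None)
def get_zigzag_truncated_indices(h=8, w=8, q=6):
    # Randomly pick one of the two zig-zag orientations (the second is the
    # transpose of the first), then keep the first q positions.
    # Note: because of lru_cache, the orientation is drawn only once per
    # (h, w, q) and reused on every later call with the same arguments.
    if random.randint(0, 1):
        x, y = get_zigzag_ordered_indices(h, w)
    else:
        y, x = get_zigzag_ordered_indices(w, h)
    return x[:q], y[:q]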

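@lru_cache(maxsize=None)
def get_mask_DCT(h=8, w=8, q=6, device='cuda'):
    # Shape (1, 1, h, w) so the mask broadcasts over a [b, c, h, w] batch.
    # The default device is 'cuda'; pass device='cpu' on machines without a GPU.
    mask_DCT = torch.zeros(1, 1, h, w, device=device)
    DCT_x, DCT_y = get_zigzag_truncated_indices(h, w, q)
    mask_DCT[:, :, DCT_x, DCT_y] = 1
    return mask_DCT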

print(get_mask_DCT())
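
One caveat about the code above: get_zigzag_truncated_indices makes a random choice inside a function decorated with lru_cache, so the choice is made once per argument combination and then replayed from the cache. The toy example below shows that effect in isolation; cached_coin and fresh_coin are hypothetical names used only for this illustration.

```python
import random
from functools import lru_cache

@lru_cache(maxsize=None)
def cached_coin(key):
    # Hypothetical stand-in for a cached function that makes a random choice.
    return random.randint(0, 1)

def fresh_coin(key):
    # Same body without the cache: a new draw on every call.
    return random.randint(0, 1)

cached_draws = {cached_coin("a") for _ in range(100)}
fresh_draws = {fresh_coin("a") for _ in range(100)}

print(len(cached_draws))  # always 1: the first draw is reused
print(len(fresh_draws))   # typically 2
```

If a fresh orientation per call is wanted, the random draw would have to happen outside the cached functions. Likewise, the cached get_mask_DCT hands back the same tensor object each time, so it is safer to clone it before any in-place modification.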

Is there a more efficient PyTorch implementation?
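
One possible loop-free alternative, offered as a sketch and not as a benchmarked answer: give every cell a sort key built from its anti-diagonal index and its position within that diagonal, turn the keys into zig-zag ranks, and compare the ranks against q. The names zigzag_rank and zigzag_mask are made up for this example, and the orientation is fixed to the one used in the examples in the question (no random transpose).

```python
import torch

def zigzag_rank(h, w, device="cpu"):
    """Return an [h, w] tensor whose entries are the zig-zag positions of the cells."""
    i = torch.arange(h, device=device).view(h, 1)  # row index
    j = torch.arange(w, device=device).view(1, w)  # column index
    s = i + j                                      # anti-diagonal index
    # Odd diagonals are walked with the row increasing, even diagonals with
    # the column increasing; this matches the order in the question's examples.
    pos = torch.where(s % 2 == 1, i, j)
    key = (s * max(h, w) + pos).flatten()          # unique sort key per cell
    order = torch.argsort(key)
    rank = torch.empty_like(order)
    rank[order] = torch.arange(h * w, device=device)
    return rank.view(h, w)

def zigzag_mask(h, w, q, device="cpu"):
    """Binary [h, w] mask that keeps the first q zig-zag positions."""
    return (zigzag_rank(h, w, device) < q).float()

print(zigzag_mask(3, 3, 5))

# One mask per sample: q has shape [b], the result has shape [b, 1, h, w].
q = torch.tensor([1, 5, 9])
batch_masks = (zigzag_rank(3, 3) < q.view(-1, 1, 1, 1)).float()
print(batch_masks.shape)
```

Since the rank tensor depends only on (h, w), it could be computed once and reused, leaving a single comparison per forward pass even when q differs per sample.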