How can I implement a custom (Hilbert-curve) flatten function with GPU acceleration in PyTorch?

Hello everyone,

I’m currently working on a project where I need to flatten images while preserving spatial locality, and I’m exploring the use of a Hilbert curve for this purpose. I’ve been inspired by 3Blue1Brown’s video on the topic (https://www.youtube.com/watch?v=3s7h2MHQtxc) and would like to implement a custom flatten function that orders an image’s pixels along a Hilbert curve.

While I understand the concept and have experimented with some basic CPU implementations, I’m now interested in accelerating this process using PyTorch and GPU support. However, I’m not entirely sure how to optimize this custom operation for GPU execution, particularly with respect to leveraging PyTorch’s capabilities.

Right now, the code below is very slow, both locally and on Google Colab.
Could anyone point me in the right direction on how to implement such a custom operation with efficient GPU acceleration in PyTorch? Any references, tips, or code examples would be greatly appreciated!

Here is the code I’m playing with, followed by the output it prints:

import functools
import math

import numpy as np
import torch

# Sign helper used to turn the span vectors into unit step directions.
def sgn(x):
    """Return 1 if *x* is positive, -1 if it is negative, and 0 otherwise."""
    if x > 0:
        return 1
    if x < 0:
        return -1
    return 0

# Recursive generate2d function for Gilbert curve
def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int):
    """Return the Gilbert-curve cells of one sub-rectangle as a list of (x, y) tuples.

    The rectangle starts at cell (x, y). The vector (ax, ay) spans its major
    axis and (bx, by) spans its orthogonal axis. It is |ax + ay| cells wide
    and |bx + by| cells high. (Presumably one component of each vector is 0,
    as in the calls made by gilbert2d; confirm before using other inputs.)

    An empty rectangle (either extent 0) yields []. Without that guard, an
    all-zero rectangle would recurse forever, because it never reaches the
    width-1 / height-1 base cases.

    The recursion appends to one shared list instead of concatenating child
    results at every level. That avoids re-copying the path at each depth
    while producing exactly the same visit order.
    """
    out = []
    _generate2d_into(out, x, y, ax, ay, bx, by)
    return out


def _generate2d_into(out, x, y, ax, ay, bx, by):
    """Append the cells of one sub-rectangle to *out* (recursive worker for generate2d)."""
    # Width and height of the grid to fill
    w = abs(ax + ay)
    h = abs(bx + by)

    # Nothing to visit. This also stops the 0 x 0 case from recursing forever.
    if w == 0 or h == 0:
        return

    # Unit step along the major (a) and orthogonal (b) directions
    dax, day = sgn(ax), sgn(ay)
    dbx, dby = sgn(bx), sgn(by)

    # Trivial row fill
    if h == 1:
        for _ in range(w):
            out.append((x, y))
            x, y = x + dax, y + day
        return

    # Trivial column fill
    if w == 1:
        for _ in range(h):
            out.append((x, y))
            x, y = x + dbx, y + dby
        return

    # Halve the span vectors (floor division, as in the original)
    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2

    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        # Long case: split the major axis into two halves only.
        if w2 % 2 and w > 2:
            # Prefer an even-length first half
            ax2, ay2 = ax2 + dax, ay2 + day

        _generate2d_into(out, x, y, ax2, ay2, bx, by)
        _generate2d_into(out, x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by)

    else:
        # Standard case: one step up, one long horizontal run, one step down.
        if h2 % 2 and h > 2:
            # Prefer an even-length first half
            bx2, by2 = bx2 + dbx, by2 + dby

        _generate2d_into(out, x, y, bx2, by2, ax2, ay2)
        _generate2d_into(out, x + bx2, y + by2, ax, ay, bx - bx2, by - by2)
        _generate2d_into(out,
                         x + (ax - dax) + (bx2 - dbx),
                         y + (ay - day) + (by2 - dby),
                         -bx2, -by2, -(ax - ax2), -(ay - ay2))

# Top-level gilbert2d function
def gilbert2d(width, height):
    """Return the Gilbert (generalized Hilbert) path over a width x height grid.

    The result is the list of (x, y) cells produced by generate2d, starting
    at (0, 0). The longer side of the grid is used as the major axis.

    A grid with a zero-length side has no cells, so [] is returned directly.
    Without this guard, a 0 x 0 grid recurses without ever reaching a base
    case in generate2d.
    """
    if width == 0 or height == 0:
        return []
    if width >= height:
        return generate2d(0, 0, width, 0, 0, height)
    return generate2d(0, 0, 0, height, width, 0)

# Function to reshape tensor following the Gilbert curve
@functools.lru_cache(maxsize=16)
def _gilbert_linear_index(width, height, device):
    """Return the row-major index (y * width + x) of each Gilbert-path cell.

    The indices come back as an int64 tensor on *device*, in path order.

    Building the path is pure-Python recursion, and its result depends only
    on the grid size. So it is computed once per (width, height, device) and
    reused, rather than rebuilt on every call (e.g. every batch). The cached
    tensor is shared and must not be modified in place.
    """
    path = gilbert2d(width, height)
    coords = torch.from_numpy(np.asarray(path, dtype=np.int64).reshape(-1, 2))
    linear = coords[:, 1] * width + coords[:, 0]
    return linear.to(device)


def reshape_via_gilbert(tensor, width=None, height=None):
    """Lay the elements of *tensor* out on a height x width grid along a Gilbert curve.

    *tensor* is flattened, and element k is written to the k-th cell of the
    gilbert2d(width, height) path. Cells the path does not reach stay 0. If
    the grid has fewer cells than *tensor* has elements, the trailing
    elements are dropped.

    Sizing:
      * only *height* given: width = ceil(numel / height)
      * only *width* given:  height = ceil(numel / width)
      * neither given:       height = isqrt(numel), width = ceil(numel / height)

    The scatter is a single indexed assignment on the output's device, not one
    Python-level assignment per element. On a GPU that means one kernel
    instead of numel tiny ones. The path itself is cached per grid size.

    Returns:
        A new (height, width) tensor with *tensor*'s dtype and device.
    """
    flattened_tensor = tensor.flatten()
    num_elements = flattened_tensor.numel()

    if width is None and height is not None:
        width = (num_elements + height - 1) // height
    elif height is None and width is not None:
        height = (num_elements + width - 1) // width
    elif height is None and width is None:
        height = math.isqrt(num_elements)
        # isqrt(0) == 0: an empty tensor gets an empty 0 x 0 grid instead of
        # a ZeroDivisionError.
        width = (num_elements + height - 1) // height if height else 0

    reshaped_tensor = torch.zeros((height, width), dtype=tensor.dtype, device=tensor.device)

    # No cells to fill (also avoids walking a degenerate 0 x 0 curve).
    if height == 0 or width == 0:
        return reshaped_tensor

    count = min(num_elements, width * height)
    linear_index = _gilbert_linear_index(width, height, tensor.device)[:count]
    # The path visits distinct cells, so this scatter has no collisions.
    reshaped_tensor.view(-1)[linear_index] = flattened_tensor[:count]
    return reshaped_tensor

# Example usage: the 9 x 4 matrix holding 1..36 in row-major order, reshaped
# with automatic sizing, a fixed width, a fixed height, and an explicit
# 8 x 8 grid (cells the curve does not fill stay 0).
tensor = torch.arange(1, 37).reshape(9, 4)

for options in ({}, {"width": 5}, {"height": 4}, {"width": 8, "height": 8}):
    reshaped_tensor = reshape_via_gilbert(tensor, **options)
    print(reshaped_tensor)
tensor([[ 1,  4,  5, 32, 33, 36],
        [ 2,  3,  6, 31, 34, 35],
        [11, 10,  7, 30, 27, 26],
        [12,  9,  8, 29, 28, 25],
        [13, 16, 17, 20, 21, 24],
        [14, 15, 18, 19, 22, 23]])
tensor([[ 1,  4,  5,  8,  9],
        [ 2,  3,  6,  7, 10],
        [19, 18, 15, 14, 11],
        [20, 17, 16, 13, 12],
        [21, 24, 25, 28, 29],
        [22, 23, 26, 27, 30],
        [ 0,  0, 35, 34, 31],
        [ 0,  0, 36, 33, 32]])
tensor([[ 1,  2, 15, 16, 17, 18, 34, 35, 36],
        [ 4,  3, 14, 13, 20, 19, 33, 32, 31],
        [ 5,  8,  9, 12, 21, 24, 25, 30, 29],
        [ 6,  7, 10, 11, 22, 23, 26, 27, 28]])
tensor([[ 1,  4,  5,  6,  0,  0,  0,  0],
        [ 2,  3,  8,  7,  0,  0,  0,  0],
        [15, 14,  9, 10,  0,  0,  0,  0],
        [16, 13, 12, 11,  0,  0,  0,  0],
        [17, 18, 31, 32, 33, 34,  0,  0],
        [20, 19, 30, 29, 36, 35,  0,  0],
        [21, 24, 25, 28,  0,  0,  0,  0],
        [22, 23, 26, 27,  0,  0,  0,  0]])

Credit for the generalized Hilbert ("Gilbert") algorithm goes to Jakub Červený: GitHub – jakubcerveny/gilbert: Space-filling curve for rectangular domains of arbitrary size.

Thank you in advance for your help.