Color quantization?

Is there anything similar to PIL’s Image.quantize() in PyTorch?

quantize() reduces the number of colors in an image, with an option to map them to a specific set of target colors. I am sure a similar function could be created for tensors, but would the gradient be maintained? If yes, any tips on implementing it in an efficient way?
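
For reference, this is roughly what I mean in PIL (just a sketch; the palette-image trick maps pixels to a fixed set of target colors):

from PIL import Image

img = Image.open(r"[PATH]").convert("RGB")

# Reduce to 16 automatically chosen colors
reduced = img.quantize(colors=16)

# Or quantize towards a specific set of target colors via a "P"-mode palette image
palette_img = Image.new("P", (1, 1))
palette_img.putpalette([0, 75, 135, 255, 205, 0] + [0, 0, 0] * 254)  # pad to 256 entries
targeted = img.quantize(palette=palette_img)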

For now I created the following methods:

import torch
import torchvision.transforms.functional as TF
from PIL import Image
from IPython.display import display


def distance(color_a, color_b):
    differences = torch.stack([(c_a - c_b) ** 2 for c_a, c_b in zip(color_a, color_b)])
    return torch.sum(differences, dim=-2) ** 0.5 # Euclidean distance

def quantize(image, palette):
    # Assume RGB
    flat_img = image.view(1, 3, -1) # [C, H, W] -> [1, C, H*W]
    img_per_palette_color = torch.cat(len(palette) * [flat_img]) # [1, C, H*W] -> [n_colors, C, H*W]
    distance_per_pixel = distance(img_per_palette_color, palette.unsqueeze(-1)) # [n_colors, C, H*W] -> [n_colors, H*W]
    color_indices = torch.argmin(distance_per_pixel, dim=0) # [n_colors, H*W] -> [H*W]
    new_colors = palette[color_indices].T # [H*W] -> [C, H*W]
    return new_colors.view(image.shape) # [C, H*W] -> [C, H, W]

colors = torch.tensor([
    [0, 75, 135],
    [255, 205, 0]
])
colors_normalized = colors/255

original = TF.to_tensor(Image.open(r"[PATH]").convert("RGB"))
display(TF.to_pil_image(original))

quantized = quantize(original, colors_normalized)
display(TF.to_pil_image(quantized))

Which outputs two images: the original, followed by the quantized version.

I will investigate if this breaks the gradient and if I can make the function general for multiple images and multiple palettes.
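
A quick way to check the gradient part (just a sanity check, reusing the tensors above):

test_img = original.clone().requires_grad_(True)
out = quantize(test_img, colors_normalized)

# If this prints False, the output is no longer connected to the input
# in the autograd graph, so backprop cannot reach the image.
print(out.requires_grad)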

The previous implementation does not maintain the gradient since it uses argmin(). I have coded a new version which does:

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
    Only works for one image i.e. CHW. Does NOT work for batches.
    """

    C, H, W = image.shape
    n_colors = len(palette)

    # Easier to work with list of colors
    flat_img = image.view(C, -1).T # [C, H, W] -> [H*W, C]

    # Repeat image so that there are n_color number of columns of the same image
    flat_img_per_color = torch.stack(n_colors*[flat_img], dim=-2) # [H*W, C] -> [H*W, n_colors, C]

    # Get the Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    distance_per_pixel = torch.sum((flat_img_per_color - palette) ** 2, dim=-1) ** 0.5 # [H*W, n_colors, C] -> [H*W, n_colors]

    # Get the shortest distance (one value per row (H*W) is selected)
    min_distances = torch.min(distance_per_pixel, dim=-1).values # [H*W, n_colors] -> [H*W]

    # Get difference between each distance and the shortest distance.
    # One value per column (the selected value) will become 0.
    per_color_difference = distance_per_pixel - torch.stack(n_colors*[min_distances], dim=-1) # [H*W, n_colors]

    # Round all values up and invert. Creates something similar to one-hot encoding.
    per_color_diff_scaled = 1 - torch.ceil(per_color_difference)

    # Multiply the "kinda" one-hot encoded per_color_diff_scaled with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled, palette)

    # Reshape it back to the original input format.
    quantized_img = quantized.T.view(C, H, W) # [H*W, C] -> [C, H, W]

    return quantized_img

I am, however, concerned with how well the gradient is maintained.
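
To get a feel for it, I ran a backward pass on random tensors and checked how much gradient actually reaches each input (a rough check, not a proper gradcheck):

img = torch.rand(3, 8, 8, requires_grad=True)
pal = torch.rand(4, 3, requires_grad=True)

out = quantize(img, pal)
out.sum().backward()

# All-zero (or missing) gradients mean that input is effectively cut off,
# even though backward() runs without errors.
print(img.grad.abs().max(), pal.grad.abs().max())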

I found this post while searching for workarounds for argmin():

I am not too familiar with weighted sums, but is that approach similar to my new implementation?
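
If I understand it correctly, the weighted-sum idea replaces the hard argmin with a softmax over the (negative) distances, so every pixel becomes a soft mix of palette colors. Something like this sketch (the temperature parameter is my own addition to control how sharp the assignment is):

def soft_quantize(image, palette, temperature=0.01):
    # Soft assignment: each pixel is a softmax-weighted mix of the palette
    # colors instead of a hard nearest-color pick, so it stays differentiable.
    C, H, W = image.shape
    flat_img = image.view(C, -1).T                        # [C, H, W] -> [H*W, C]
    dist = torch.cdist(flat_img, palette)                 # [H*W, n_colors]
    weights = torch.softmax(-dist / temperature, dim=-1)  # [H*W, n_colors]
    quantized = weights @ palette                         # [H*W, C]
    return quantized.T.view(C, H, W)                      # [H*W, C] -> [C, H, W]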

The quantize() above still broke the gradient in some specific cases. I finally found this thread, which gives a weird hack that solved the problem:
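
The problem is that the gradient of sqrt() blows up at exactly zero, which happens whenever a pixel exactly matches a palette color; the result is inf/NaN gradients. A minimal reproduction of the issue and the epsilon fix:

a = torch.zeros(3, requires_grad=True)
b = torch.zeros(3)

# Distance is exactly zero -> gradient is 0/0 = NaN
d = torch.sqrt(((a - b) ** 2).sum())
d.backward()
print(a.grad)  # tensor of NaNs

# With a small epsilon inside the sqrt the gradient stays finite
a.grad = None
d = torch.sqrt(((a - b) ** 2).sum() + 1e-8)
d.backward()
print(a.grad)  # tensor of zeros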

Final code, in case someone has a similar problem in the future:

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
    Only works for one image i.e. CHW. Does NOT work for batches.
    """

    C, H, W = image.shape
    n_colors = len(palette)

    # Easier to work with list of colors
    flat_img = image.view(C, -1).T # [C, H, W] -> [H*W, C]

    # Repeat image so that there are n_color number of columns of the same image
    flat_img_per_color = torch.stack(n_colors*[flat_img], dim=-2) # [H*W, C] -> [H*W, n_colors, C]

    # Get the Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    squared_distance = (flat_img_per_color - palette) ** 2
    # Dirty cursed hack: add a small epsilon so sqrt() never sees an exact zero
    # https://discuss.pytorch.org/t/runtimeerror-function-sqrtbackward-returned-nan-values-in-its-0th-output/48702/4
    euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8) # [H*W, n_colors, C] -> [H*W, n_colors]

    # Get the shortest distance (one value per row (H*W) is selected)
    min_distances = torch.min(euclidean_distance, dim=-1).values # [H*W, n_colors] -> [H*W]

    # Get difference between each distance and the shortest distance.
    # One value per column (the selected value) will become 0.
    per_color_difference = euclidean_distance - torch.stack(n_colors*[min_distances], dim=-1) # [H*W, n_colors]

    # Round all values up and invert. Creates something similar to one-hot encoding.
    per_color_diff_scaled = 1 - torch.ceil(per_color_difference)

    per_color_diff_scaled_ = per_color_diff_scaled  # alternative: (per_color_diff_scaled.T / per_color_diff_scaled.sum(dim=-1)).T

    # Multiply the "kinda" one-hot encoded per_color_diff_scaled with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled_, palette)

    # Reshape it back to the original input format.
    quantized_img = quantized.T.view(C, H, W) # [H*W, C] -> [C, H, W]

    return quantized_img
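
Usage is the same as before; backward() should now run without the sqrt NaN issue (sketch, reusing the palette from the first post):

image = TF.to_tensor(Image.open(r"[PATH]").convert("RGB")).requires_grad_(True)

quantized = quantize(image, colors_normalized)
quantized.mean().backward()  # no more NaNs from the sqrt backward thanks to the + 1e-8

display(TF.to_pil_image(quantized.detach()))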