# Color quantization?

Is there anything similar to PIL’s Image.quantize() in PyTorch?

Quantize reduces the number of colors in an image, with an option to map them to a specific set of target colors. I am sure a similar function could be created for tensors, but would the gradient be maintained? If yes, any tips on implementing it efficiently?

For now I created the following methods:

```python
import torch
import torchvision.transforms.functional as TF
from PIL import Image

def distance(color_a, color_b):
    # Per-channel squared differences, summed over the channel dimension
    differences = torch.stack([(c_a - c_b) ** 2 for c_a, c_b in zip(color_a, color_b)])
    return torch.sum(differences, dim=1)  # [n_colors, C, H*W] -> [n_colors, H*W]

def quantize(image, palette):
    # Assume RGB
    flat_img = image.view(1, 3, -1)  # [C, H, W] -> [1, C, H*W]
    img_per_palette_color = torch.cat(len(palette) * [flat_img])  # [1, C, H*W] -> [n_colors, C, H*W]
    distance_per_pixel = distance(img_per_palette_color, palette.unsqueeze(-1))  # [n_colors, C, H*W] -> [n_colors, H*W]
    color_indices = torch.argmin(distance_per_pixel, dim=0)  # [n_colors, H*W] -> [H*W]
    new_colors = palette[color_indices].T  # [H*W] -> [C, H*W]
    return new_colors.reshape(image.shape)  # [C, H*W] -> [C, H, W] (reshape, since .T makes the tensor non-contiguous)

colors = torch.tensor([
    [0, 75, 135],
    [255, 205, 0]
])
colors_normalized = colors / 255

original = TF.to_tensor(Image.open(r"[PATH]").convert("RGB"))
display(TF.to_pil_image(original))  # display() as in a Jupyter notebook

quantized = quantize(original, colors_normalized)
display(TF.to_pil_image(quantized))
```

Which outputs:

(1st one is the original, 2nd one is quantized)

I will investigate whether this breaks the gradient and whether the function can be generalized to multiple images and multiple palettes.
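A standalone toy check makes the gradient question easy to answer: `argmin` returns an integer index tensor, which carries no gradient, and indexing the palette with it disconnects the output from the input pixels in the autograd graph (the tensors here are small examples, not the image above):

```python
import torch

palette = torch.tensor([[0.0, 0.3, 0.5], [1.0, 0.8, 0.0]], requires_grad=True)
pixels = torch.rand(5, 3, requires_grad=True)

# Hard assignment: argmin yields integer indices with no gradient, and
# indexing the palette with them severs the graph back to `pixels`.
dists = ((pixels.unsqueeze(1) - palette) ** 2).sum(dim=-1)  # [5, 2]
idx = torch.argmin(dists, dim=-1)
out = palette[idx]  # [5, 3]

out.sum().backward()
print(pixels.grad)   # None: no gradient flows back to the input pixels
print(palette.grad)  # populated: gradients still reach the palette via indexing
```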

The previous implementation does not maintain the gradient, since it uses `argmin`. I have coded a new version which does:

```python
import torch

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain the gradient.
    Only works for one image, i.e. CHW. Does NOT work for batches.
    """
    C, H, W = image.shape
    n_colors = len(palette)

    # Easier to work with a list of colors
    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]

    # Repeat the image so that there are n_colors columns of the same image
    flat_img_per_color = torch.stack(n_colors * [flat_img], dim=-2)  # [H*W, C] -> [H*W, n_colors, C]

    # Get the Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    distance_per_pixel = torch.sum((flat_img_per_color - palette) ** 2, dim=-1) ** 0.5  # [H*W, n_colors, C] -> [H*W, n_colors]

    # Get the shortest distance (one value per row (H*W) is selected)
    min_distances = torch.min(distance_per_pixel, dim=-1).values  # [H*W, n_colors] -> [H*W]

    # Get the difference between each distance and the shortest distance.
    # One value per row (the selected one) becomes 0.
    per_color_difference = distance_per_pixel - torch.stack(n_colors * [min_distances], dim=-1)  # [H*W, n_colors]

    # Round all values up and invert. Creates something similar to a one-hot encoding.
    per_color_diff_scaled = 1 - torch.ceil(per_color_difference)

    # Multiply the "kinda" one-hot encoded per_color_diff_scaled with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled, palette)  # [H*W, n_colors] -> [H*W, C]

    # Reshape back to the original input format (reshape, since .T makes the tensor non-contiguous).
    quantized_img = quantized.T.reshape(C, H, W)  # [H*W, C] -> [C, H, W]

    return quantized_img
```

I am, however, concerned with how well the gradient is maintained.

I found this post while searching for implementations around `argmin()`:

I am not too familiar with weighted sums, but is it similar to my new implementation?
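The weighted-sum idea usually means replacing the hard assignment with a softmax over negative distances, so every palette color contributes a little and gradients flow smoothly to both the image and the palette. A sketch for comparison (the name `soft_quantize` and the temperature `tau` are my own, not from the linked post):

```python
import torch

def soft_quantize(image, palette, tau=0.01):
    """Differentiable 'soft' quantization: each pixel becomes a softmax-weighted
    sum of the palette colors. As tau -> 0 this approaches hard nearest-color
    assignment. image: [C, H, W], palette: [n_colors, C]."""
    C, H, W = image.shape
    flat = image.view(C, -1).T                              # [H*W, C]
    d2 = ((flat.unsqueeze(1) - palette) ** 2).sum(dim=-1)   # [H*W, n_colors]
    weights = torch.softmax(-d2 / tau, dim=-1)              # [H*W, n_colors]
    out = weights @ palette                                 # [H*W, C]
    return out.T.reshape(C, H, W)
```

Unlike the ceil-based one-hot trick, this never produces an exact palette color, but the output is differentiable everywhere and `tau` controls how close it gets to a hard assignment.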

It still broke gradients in some specific cases. I finally found a thread giving a weird hack that solved the problem (link in the code below).
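The hack targets `sqrt()`'s gradient at zero: d/du sqrt(u) = 1/(2*sqrt(u)) is infinite at u = 0, and autograd turns the resulting inf * 0 into NaN whenever a pixel exactly matches a palette color. A small epsilon inside the sqrt keeps the gradient finite:

```python
import torch

x = torch.zeros(3, requires_grad=True)
y = torch.sqrt((x ** 2).sum())
y.backward()
print(x.grad)   # tensor([nan, nan, nan]): sqrt's gradient blows up at 0

x2 = torch.zeros(3, requires_grad=True)
y2 = torch.sqrt((x2 ** 2).sum() + 1e-8)  # the epsilon keeps sqrt away from 0
y2.backward()
print(x2.grad)  # tensor([0., 0., 0.]): finite gradient
```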

Final code, if someone has a similar problem in the future:

```python
import torch

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain the gradient.
    Only works for one image, i.e. CHW. Does NOT work for batches.
    """
    C, H, W = image.shape
    n_colors = len(palette)

    # Easier to work with a list of colors
    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]

    # Repeat the image so that there are n_colors columns of the same image
    flat_img_per_color = torch.stack(n_colors * [flat_img], dim=-2)  # [H*W, C] -> [H*W, n_colors, C]

    # Get the Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    squared_distance = (flat_img_per_color - palette) ** 2
    # Dirty cursed hack: the epsilon keeps sqrt() away from exactly 0, where its gradient is NaN.
    # https://discuss.pytorch.org/t/runtimeerror-function-sqrtbackward-returned-nan-values-in-its-0th-output/48702/4
    euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8)  # [H*W, n_colors, C] -> [H*W, n_colors]

    # Get the shortest distance (one value per row (H*W) is selected)
    min_distances = torch.min(euclidean_distance, dim=-1).values  # [H*W, n_colors] -> [H*W]

    # Get the difference between each distance and the shortest distance.
    # One value per row (the selected one) becomes 0.
    per_color_difference = euclidean_distance - torch.stack(n_colors * [min_distances], dim=-1)  # [H*W, n_colors]

    # Round all values up and invert. Creates something similar to a one-hot encoding.
    per_color_diff_scaled = 1 - torch.ceil(per_color_difference)

    per_color_diff_scaled_ = per_color_diff_scaled  # (per_color_diff_scaled.T / per_color_diff_scaled.sum(dim=-1)).T

    # Multiply the "kinda" one-hot encoded per_color_diff_scaled with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled_, palette)  # [H*W, n_colors] -> [H*W, C]

    # Reshape back to the original input format (reshape, since .T makes the tensor non-contiguous).
    quantized_img = quantized.T.reshape(C, H, W)  # [H*W, C] -> [C, H, W]

    return quantized_img
```