The previous implementation does not maintain the gradient since it uses argmin(). I have written a new version which does:
import torch

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain the gradient.
    Only works for one image, i.e. CHW. Does NOT work for batches.
    image:   float tensor of shape [C, H, W]
    palette: float tensor of shape [n_colors, C]
    """
    C, H, W = image.shape
    n_colors = len(palette)
    # Easier to work with a list of colors
    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]
    # Repeat the image so that there are n_colors copies of it, one per palette color
    flat_img_per_color = torch.stack(n_colors * [flat_img], dim=-2)  # [H*W, C] -> [H*W, n_colors, C]
    # Get the Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    distance_per_pixel = torch.sum((flat_img_per_color - palette) ** 2, dim=-1) ** 0.5  # [H*W, n_colors, C] -> [H*W, n_colors]
    # Get the shortest distance (one value per row (H*W) is selected)
    min_distances = torch.min(distance_per_pixel, dim=-1).values  # [H*W, n_colors] -> [H*W]
    # Get the difference between each distance and the shortest distance.
    # One value per row (the selected one) becomes 0.
    per_color_difference = distance_per_pixel - torch.stack(n_colors * [min_distances], dim=-1)  # [H*W, n_colors]
    # Round all values up, invert, and clamp. Creates something similar to a one-hot encoding:
    # the nearest color gets weight 1, the rest 0. (The clamp is needed because a difference
    # greater than 1 would otherwise produce a negative weight.)
    per_color_diff_scaled = torch.clamp(1 - torch.ceil(per_color_difference), min=0)
    # Multiply the "kinda" one-hot encoded per_color_diff_scaled with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled, palette)
    # Reshape it back to the original input format.
    # (reshape rather than view, because the transpose makes the tensor non-contiguous)
    quantized_img = quantized.T.reshape(C, H, W)  # [H*W, C] -> [C, H, W]
    return quantized_img
I am, however, concerned with how well the gradient is actually maintained.
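As a quick sanity check (a snippet of my own, with made-up shapes: a random 3x4x4 image and a 5-color RGB palette), I would expect the image to receive an all-zero gradient, because torch.ceil has a zero derivative everywhere it is defined, while the palette still receives one through the matmul:

import torch

image = torch.rand(3, 4, 4, requires_grad=True)   # hypothetical test image
palette = torch.rand(5, 3, requires_grad=True)    # hypothetical 5-color palette

out = quantize(image, palette)
out.sum().backward()

print(image.grad.abs().max())    # tensor(0.) -- ceil() blocks the gradient to the image
print(palette.grad.abs().max())  # non-zero -- the palette gets a gradient through the matmul

So the computation graph stays connected, but the path back to the image goes through ceil(), whose gradient is zero.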
I found this post while searching for implementations around argmin():
I am not too familiar with weighted sums, but is the approach described there similar to my new implementation?
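In case it helps frame the question, here is my rough sketch of what I think the weighted-sum idea looks like (my own code, not from that post; the function name soft_quantize and the temperature parameter tau are mine): instead of picking the single nearest color, every palette color contributes, weighted by a softmax over the negative distances, so the gradient stays non-zero:

import torch

def soft_quantize(image, palette, tau=0.01):
    # Softmax-weighted sum over the palette: a differentiable stand-in for argmin.
    # As tau -> 0 the weights approach a hard one-hot selection of the nearest
    # color; for finite tau every color contributes a little, so gradients flow.
    C, H, W = image.shape
    flat_img = image.view(C, -1).T                     # [C, H, W] -> [H*W, C]
    distances = torch.cdist(flat_img, palette)         # [H*W, n_colors]
    weights = torch.softmax(-distances / tau, dim=-1)  # [H*W, n_colors]
    quantized = weights @ palette                      # [H*W, C]
    return quantized.T.reshape(C, H, W)                # [H*W, C] -> [C, H, W]

With a small tau the output looks like hard quantization but the gradient with respect to the image is tiny; with a larger tau the colors blend, but the gradient is better behaved. If that matches what the post means, it differs from my version mainly in that the one-hot weights are replaced by smooth softmax weights.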