Quantize reduces the number of colors in an image, with an option to map them to a specific set of target colors. I am sure a similar function could be created for tensors, but would the gradient be maintained? If so, any tips on implementing it efficiently?
The previous implementation does not maintain the gradient, since it uses argmin; a rough sketch of that approach is below, followed by the new version I have coded, which does.
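Roughly, the argmin-based version looked like this (a reconstruction, not my exact original code; indexing the palette with the integer indices from argmin() is what detaches the output from the input image):

import torch

def quantize_argmin(image, palette):
    # Non-differentiable reference: image is [C, H, W], palette is [n_colors, C].
    C, H, W = image.shape
    flat_img = image.view(C, -1).T          # [H*W, C]
    dists = torch.cdist(flat_img, palette)  # [H*W, n_colors] pairwise Euclidean distances
    nearest = torch.argmin(dists, dim=-1)   # integer indices: no gradient flows through these
    return palette[nearest].T.reshape(C, H, W)

And the new version: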
import torch

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
    Only works for one image, i.e. CHW. Does NOT work for batches.
    """
    C, H, W = image.shape
    n_colors = len(palette)
    # Easier to work with a flat list of pixels
    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]
    # Repeat the image so that there are n_colors columns of the same image
    flat_img_per_color = torch.stack(n_colors * [flat_img], dim=-2)  # [H*W, C] -> [H*W, n_colors, C]
    # Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    distance_per_pixel = torch.sum((flat_img_per_color - palette) ** 2, dim=-1) ** 0.5  # [H*W, n_colors, C] -> [H*W, n_colors]
    # Shortest distance per pixel (one value per row (H*W) is selected)
    min_distances = torch.min(distance_per_pixel, dim=-1).values  # [H*W, n_colors] -> [H*W]
    # Difference between each distance and the shortest distance.
    # One entry per row (the selected one) becomes 0.
    per_color_difference = distance_per_pixel - torch.stack(n_colors * [min_distances], dim=-1)  # [H*W, n_colors]
    # Round all values up and invert. Creates something similar to a one-hot encoding.
    # (Caveat: a difference >= 1 makes ceil() return > 1, so that weight goes negative.)
    per_color_diff_scaled = 1 - torch.ceil(per_color_difference)
    # Multiply the "kinda" one-hot encoded weights with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled, palette)  # [H*W, n_colors] x [n_colors, C] -> [H*W, C]
    # Reshape back to the original input format; reshape() instead of view(),
    # since the transpose makes the tensor non-contiguous.
    quantized_img = quantized.T.reshape(C, H, W)  # [H*W, C] -> [C, H, W]
    return quantized_img
I am, however, concerned with how well the gradient is actually maintained; a quick check is below.
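One way to probe this is to run a backward pass and inspect the input's gradient. If I read ceil()'s backward correctly, it has a zero derivative everywhere, so the gradient reaching the image through the weights is identically zero; the check below (made-up palette and sizes) at least confirms backward runs and produces no NaNs:

import torch

image = torch.rand(3, 8, 8, requires_grad=True)
palette = torch.tensor([[0.0, 0.0, 0.0],
                        [1.0, 1.0, 1.0],
                        [1.0, 0.0, 0.0]])
out = quantize(image, palette)
out.sum().backward()
print(image.grad.shape)               # torch.Size([3, 8, 8])
print(torch.isnan(image.grad).any())  # tensor(False) for random inputs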
I found this post while searching for implementations around argmin():
I am not too familiar with weighted sums, but is the idea similar to my new implementation? My attempt at writing it down follows.
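If I understand the weighted-sum idea correctly, it replaces the hard selection with a softmax over negative distances, something like the sketch below (my own reading of it; the function name and temperature parameter are made up):

import torch

def quantize_soft(image, palette, temperature=0.01):
    # Soft relaxation: each pixel becomes a softmax-weighted mix of palette colors.
    C, H, W = image.shape
    flat_img = image.view(C, -1).T                         # [H*W, C]
    dists = torch.cdist(flat_img, palette)                 # [H*W, n_colors]
    weights = torch.softmax(-dists / temperature, dim=-1)  # near one-hot as temperature -> 0
    quantized = weights @ palette                          # [H*W, C]
    return quantized.T.reshape(C, H, W)

It is similar in that both compute weights times palette; the difference is that softmax weights have nonzero derivatives everywhere, whereas my ceil()-based weights have a zero derivative, so this version actually lets gradient reach the image, at the cost of only approximately quantized colors.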
It still broke the gradient in some specific cases. I finally found this thread giving a weird hack that solved the problem (the link is in the code comment below).
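The failure is reproducible in isolation: sqrt() has an infinite derivative at 0, and the chain rule then computes 0 * inf = NaN, even though the incoming gradient from ceil() is zero. A tiny repro of the problem and the epsilon fix:

import torch

x = torch.zeros(1, requires_grad=True)
(torch.sqrt(x) * 0).sum().backward()
print(x.grad)  # tensor([nan]): 0 * inf from sqrt's backward at 0

x.grad = None
(torch.sqrt(x + 1e-8) * 0).sum().backward()
print(x.grad)  # tensor([0.]): the epsilon keeps the local derivative finite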
Final code, in case someone has a similar problem in the future:
import torch

def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
    Only works for one image, i.e. CHW. Does NOT work for batches.
    """
    C, H, W = image.shape
    n_colors = len(palette)
    # Easier to work with a flat list of pixels
    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]
    # Repeat the image so that there are n_colors columns of the same image
    flat_img_per_color = torch.stack(n_colors * [flat_img], dim=-2)  # [H*W, C] -> [H*W, n_colors, C]
    # Euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    squared_distance = (flat_img_per_color - palette) ** 2
    # Dirty cursed hack: the epsilon keeps sqrt away from 0, where its backward produces NaN
    # https://discuss.pytorch.org/t/runtimeerror-function-sqrtbackward-returned-nan-values-in-its-0th-output/48702/4
    euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8)  # [H*W, n_colors, C] -> [H*W, n_colors]
    # Shortest distance per pixel (one value per row (H*W) is selected)
    min_distances = torch.min(euclidean_distance, dim=-1).values  # [H*W, n_colors] -> [H*W]
    # Difference between each distance and the shortest distance.
    # One entry per row (the selected one) becomes 0.
    per_color_difference = euclidean_distance - torch.stack(n_colors * [min_distances], dim=-1)  # [H*W, n_colors]
    # Round all values up and invert. Creates something similar to a one-hot encoding.
    per_color_diff_scaled = 1 - torch.ceil(per_color_difference)
    # Optional normalization so that ties still sum to 1 (left commented out):
    per_color_diff_scaled_ = per_color_diff_scaled  # (per_color_diff_scaled.T / per_color_diff_scaled.sum(dim=-1)).T
    # Multiply the "kinda" one-hot encoded weights with the palette colors.
    # The result is a quantized image.
    quantized = torch.matmul(per_color_diff_scaled_, palette)  # [H*W, n_colors] x [n_colors, C] -> [H*W, C]
    # Reshape back to the original input format; reshape() instead of view(),
    # since the transpose makes the tensor non-contiguous.
    quantized_img = quantized.T.reshape(C, H, W)  # [H*W, C] -> [C, H, W]
    return quantized_img
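For completeness, a small self-check (hypothetical two-color palette) that forces the exact-match case which used to produce NaNs:

import torch

palette = torch.tensor([[0.0, 0.0, 0.0],
                        [1.0, 1.0, 1.0]])
image = torch.rand(3, 4, 4)
image[:, 0, 0] = 0.0  # exact match with the first palette color, i.e. distance 0
image.requires_grad_()

quantized = quantize(image, palette)
quantized.sum().backward()
print(torch.isnan(image.grad).any())  # tensor(False) with the epsilon in place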