Hi Team,
I am using the PIL function below to find the most dominant color in an image inside the training loop of my PyTorch script.
How can I convert this function to PyTorch, or to an autograd function?
Program
# Allowed per-channel values that pixels are "snapped" to, in descending order.
_SNAP_LEVELS = (255, 223, 191, 159, 127, 95, 63, 31, 0)


def _snap_channel(value, levels=_SNAP_LEVELS):
    """Snap one channel value to the nearest entry of *levels* (descending).

    Ties go to the larger level, matching the original ``<=`` comparison.

    Raises:
        ValueError: if *value* lies outside ``[levels[-1], levels[0]]``.
            The original loop indexed past the end of the list in this case
            and raised a bare ``IndexError``.
    """
    # Walk adjacent (high, low) pairs; the first interval containing the
    # value decides, exactly as the original scan did.
    for high, low in zip(levels, levels[1:]):
        if high >= value >= low:
            return high if high - value <= value - low else low
    raise ValueError(
        f"channel value {value!r} outside [{levels[-1]}, {levels[0]}]"
    )


def get_dominant_color(img, *, write_back=True):
    """Quantise *img*'s colours and pass the chosen colour to ``config.Generate_color``.

    Every pixel's R, G and B values are snapped to the nearest entry of
    ``_SNAP_LEVELS``. The snapped colours are then counted and ranked by
    frequency. Ties keep first-seen order, scanning column by column.

    Args:
        img: image exposing ``size``, ``getpixel`` and ``putpixel``. Each
            pixel is unpacked as exactly three channels, so this assumes an
            RGB-mode PIL image. TODO confirm for other modes.
        write_back: when True (the default, matching the original behaviour),
            each snapped colour is written back into *img* in place. Pass
            False to leave the caller's image untouched, e.g. in a training
            loop.

    Returns:
        ``config.Generate_color`` applied to a float32 tensor of the chosen
        RGB triple. The tensor is built from Python ints, so it has no
        autograd connection to *img*. The snapping and counting are
        piecewise constant, so an autograd wrapper would only yield zero
        gradients.
    """
    from collections import Counter

    width, height = img.size
    snapped_for = {}         # original colour -> snapped colour, computed once each
    color_count = Counter()  # insertion-ordered, so most_common() ties are stable

    for w in range(width):
        for h in range(height):
            current_color = img.getpixel((w, h))
            new_rgb = snapped_for.get(current_color)
            if new_rgb is None:
                r, g, b = current_color
                new_rgb = (_snap_channel(r), _snap_channel(g), _snap_channel(b))
                snapped_for[current_color] = new_rgb
            if write_back:
                # Visualises the quantisation on the caller's image.
                img.putpixel((w, h), new_rgb)
            color_count[new_rgb] += 1

    ranked = color_count.most_common()
    # NOTE(review): the single most frequent snapped colour is skipped
    # (presumably treated as background) and the runner-up is returned.
    # This preserves the original `all_colors[1:]`; confirm it is intended.
    if len(ranked) > 1:
        rgb = ranked[1][0]
    else:
        # Fewer than two distinct snapped colours: fall back to pure red.
        rgb = (255, 0, 0)
    return config.Generate_color(torch.tensor(rgb, dtype=torch.float32))