Converting a PIL-based function to PyTorch for use in a training loop

Hi Team,

I am using the PIL function below to find the dominant color of an image inside the training loop of my PyTorch script.

How can I convert this function to PyTorch, or into an autograd function? :slight_smile:

Program

def get_dominant_color(img, skip_most_common=1):
    """Return the dominant colour of ``img`` after snapping it to a coarse palette.

    Each RGB channel is snapped to the nearest value in
    ``[255, 223, 191, 159, 127, 95, 63, 31, 0]`` (an exact tie goes to the
    higher value). The snapped colours are counted and ranked by count,
    descending. Equal counts keep the order in which each colour was first
    seen, scanning the image column by column (x outer, y inner).

    The first ``skip_most_common`` ranked colours are dropped and the next one
    is returned. The default of 1 matches the original behaviour: the single
    most frequent snapped colour is ignored, presumably because it is the
    background (NOTE(review): confirm this is intended for your data).

    The work is vectorised with NumPy rather than done with per-pixel
    ``getpixel``/``putpixel`` calls. ``img`` itself is not modified.

    This computation is not differentiable: the result is a discrete choice
    (the mode of a histogram), so there is no meaningful gradient. Call it
    outside autograd (e.g. under ``torch.no_grad()``) instead of wrapping it
    in an ``autograd.Function``.

    Args:
        img: A PIL image. It is converted to RGB first, so modes such as
            ``"RGBA"``, ``"L"`` or ``"P"`` are accepted (alpha is discarded).
        skip_most_common: Non-negative number of top-ranked snapped colours
            to skip before picking the result.

    Returns:
        ``config.Generate_color`` applied to a float32 tensor holding the
        chosen ``(r, g, b)``. If there are not more than ``skip_most_common``
        distinct snapped colours, it is applied to ``[255, 0, 0]`` instead.

    Raises:
        ValueError: If ``skip_most_common`` is negative.
    """
    import numpy as np
    import torch

    if skip_most_common < 0:
        raise ValueError("skip_most_common must be non-negative")

    # The RGB values we will "snap" to.
    colors = [255, 223, 191, 159, 127, 95, 63, 31, 0]

    # 256-entry lookup table: the nearest palette value for every uint8 level.
    # The palette is in descending order and argmin returns the first minimum,
    # so a tie resolves to the higher value, as the original bracket search did.
    levels = np.array(colors, dtype=np.int64)
    distances = np.abs(np.arange(256, dtype=np.int64)[:, None] - levels[None, :])
    lut = levels[distances.argmin(axis=1)]

    # (H, W, 3) uint8 array; convert() returns a new image, so img is untouched.
    pixels = np.asarray(img.convert("RGB"))
    # Visit pixels column by column (x outer, y inner), as the original nested
    # loops did, so that ties in the counts are broken in the same order.
    pixels = pixels.transpose(1, 0, 2).reshape(-1, 3)
    snapped = lut[pixels]

    # Pack each snapped colour into a single integer so np.unique can count it.
    keys = (snapped[:, 0] << 16) | (snapped[:, 1] << 8) | snapped[:, 2]
    uniq, first_seen, counts = np.unique(keys, return_index=True, return_counts=True)
    # Primary key: count, descending. Secondary key: first appearance, ascending.
    # This reproduces a stable descending sort over first-seen order.
    ranking = np.lexsort((first_seen, -counts))

    if len(ranking) > skip_most_common:
        key = int(uniq[ranking[skip_most_common]])
        rgb = ((key >> 16) & 0xFF, (key >> 8) & 0xFF, key & 0xFF)
        dominant_color = torch.tensor(rgb, dtype=torch.float32)
    else:
        dominant_color = torch.tensor([255, 0, 0], dtype=torch.float32)
    return config.Generate_color(dominant_color)