I created this code but something is wrong, the input image should learn the best goldfish representation because I use the goldfish ImageNet category which is second category. Starting from the random image I should update the input image for a number of iterations in code.
However, the end result is I get the loss down to zero, but the image still looks like a random noise. Can you please find where I am wrong:
import torch
import torchvision.models as models
import torch.optim as optim
from PIL import Image
import numpy as np
import torch.nn.functional as F # Import functional interface for torch.nn
from torchvision.models.vgg import VGG16_Weights
# Load a pre-trained VGG16 model
vgg16 = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
vgg16.eval() # Set the model to evaluation mode
# Index of the target class (goldfish in this case)
target_class = 1 # Index 1 corresponds to 'goldfish' in ImageNet classes
# Define the optimizer (to update the input image)
input_image = torch.randn(1, 3, 224, 224).requires_grad_(True) # Initialize input image with random noise
optimizer = optim.SGD([input_image], lr=0.05) # Use the input image as parameter to optimize
# Number of optimization steps
num_iterations = 1000
# Optimization loop
for i in range(num_iterations):
print(i)
optimizer.zero_grad() # Zero gradients
# Forward pass through VGG16 with the current input image
output = vgg16(input_image)
# Apply Softmax function to the output tensor
probabilities = F.softmax(output, dim=1)
# Compute the loss - maximize the probability of the target class
loss = -probabilities[0, target_class].log() # Negative log probability of the target class
print(loss)
# Backpropagation: Compute gradients of the input image wrt the loss
loss.backward()
# Update the input image using the gradients (gradient ascent)
optimizer.step()
# Clamp the image pixel values to stay within valid range (0-1)
#input_image.data.clamp_(0, 1)
if i % 10 == 0:
# Convert the optimized input image tensor to a PIL image
generated_image = input_image.squeeze(0).detach().cpu().numpy()
generated_image = np.moveaxis(generated_image, 0, -1) # Change from [C, H, W] to [H, W, C]
generated_image = (generated_image * 255).astype(np.uint8) # Rescale to [0, 255] for PIL
# Save the generated image
image_output = Image.fromarray(generated_image)
image_output.save("generated_goldfish_image" + str(i) + ".jpg")