I’ve just started trying to make a GAN that trains off a folder of images on my computer. I’m extremely new to both PyTorch and Python as a programming language, and when I try to run train.py I keep getting the error:
‘ValueError: Using a target size (torch.Size([128, 1])) that is different to the input size (torch.Size([2048, 1])) is deprecated. Please ensure they have the same size.’
After messing around with some of the values I understand a bit more I managed to change the first number in each torch.Size[here, 1], but whenever I’d manage to alter one, the other would change in a way that I couldn’t understand. by changing batch_size i’d change the target size value in the error, but increasing it by even 10 or 20 seemed to jump the input size value up by hundreds.
I’d change the input size error value by changing the Resize() value on line 20, and while it did change only the one error value, I couldn’t find any logical correlation to what the resize numbers would be, and what the error value would end up as.
Here’s the code in question:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import numpy as np
from models.gan import Generator, Discriminator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
num_epochs = 200
batch_size = 128
learning_rate = 0.0002
latent_size = 100
# Data loading and preprocessing
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = torchvision.datasets.ImageFolder(root='data/', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize the generator and discriminator
generator = Generator()
discriminator = Discriminator()
# Loss function and optimizers
criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
# Start training
for epoch in range(num_epochs):
for i, (images, _) in enumerate(dataloader):
# Train the discriminator
discriminator.zero_grad()
# Generate fake images
noise = torch.randn(batch_size, latent_size)
fake_images = generator(noise)
# Compute the loss for real images
real_images = images
real_labels = torch.ones((batch_size, 1), dtype=torch.float32, device=device)
real_outputs = discriminator(real_images)
real_loss = criterion(real_outputs, real_labels)
# Compute the loss for fake images
fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32, device=device)
fake_outputs = discriminator(fake_images.detach())
fake_loss = criterion(fake_outputs, fake_labels)
# Backpropagate the total loss and update the parameters
d_loss = real_loss + fake_loss
d_loss.backward()
d_optimizer.step()
# Train the generator
generator.zero_grad()
# Generate fake images and compute the loss
noise = torch.randn(batch_size, latent_size)
fake_images = generator(noise)
fake_labels = torch.ones((batch_size, 1), dtype=torch.float32, device=device)
fake_outputs = discriminator(fake_images)
g_loss = criterion(fake_outputs, fake_labels)
# Backpropagate the loss and update the parameters
g_loss.backward()
g_optimizer.step()
if (i+1) % 100 == 0:
# Print the loss every 100 steps
print("Epoch [{}/{}], Step [{}/{}], D_Loss: {:.4f}, G_Loss: {:.4f}"
.format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))
# Save the generated images at the end of each epoch
if (epoch+1) % 10 == 0:
save_path = 'samples/' + str(epoch+1)
if not os.path.exists(save_path):
os.makedirs(save_path)
with torch.no_grad():
noise = torch.randn(25, latent_size)
fake_images = generator(noise)
fake_images = fake_images.reshape(25, 3, 128, 128)
torchvision.utils.save_image(fake_images, save_path + '/sample.png', nrow=5, normalize=True)
# Save the model after training
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
print('Training completed!')
I’ve checked a few similar forums but I generally can’t make much sense of things without asking weirdly specific questions or getting clarifications that others might find unnecessary.