I’m trying to train an ACGAN on CIFAR-10, but I get the following error, which I can’t make sense of. I suspect something is wrong around x = self.fc_layer(input) in the Generator, although the traceback actually points at self.conv1.
Traceback (most recent call last):
  File "/Users/.../acgan.py", line 190, in <module>
    fake_images = G(latent) #generate a fake image
  File "/Users/.../opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/.../acgan.py", line 63, in forward
    x = self.conv1(x)
  File "/Users/.../opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/.../opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 956, in forward
    return F.conv_transpose2d(
RuntimeError: non-positive stride is not supported
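To isolate the failure, I tried a minimal sketch that, as far as I can tell, reproduces the same error with a single layer, using the same constructor arguments as my conv1 and the input shape my forward() prints:

import torch
import torch.nn as nn

# same positional arguments as self.conv1 below; after the two channel counts,
# nn.ConvTranspose2d takes (kernel_size, stride, padding), so this is
# kernel_size=1, stride=0
layer = nn.ConvTranspose2d(512, 256, 1, 0)
x = torch.randn(100, 512, 1, 1)  # the shape printed in forward() before conv1
layer(x)  # RuntimeError: non-positive stride is not supported

So I suspect the 0 I pass is being read as the stride rather than the padding, but I’m not sure how the layer should be set up instead.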
This is the code:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def save_image_grid(image_tensor, img_name, num_images=100, size=(3, 64, 64)):
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=10)
    save_image(image_grid, img_name, normalize=True)
batch_size = 100
latent_size = 100
lr = 0.0002
num_epochs = 500
n_classes = len(classes)
beta_1 = 0.5
beta_2 = 0.999
class Generator(nn.Module):
    def __init__(self, nb_filter, n_classes):
        super(Generator, self).__init__()
        self.fc_layer = nn.Linear(latent_size + n_classes, nb_filter * 8)
        self.conv1 = nn.ConvTranspose2d(nb_filter * 8, nb_filter * 4, 1, 0)
        self.bn1 = nn.BatchNorm2d(nb_filter * 8)
        self.conv2 = nn.ConvTranspose2d(nb_filter * 8, nb_filter * 4, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(nb_filter * 4)
        self.conv3 = nn.ConvTranspose2d(nb_filter * 4, nb_filter * 2, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(nb_filter * 2)
        self.conv4 = nn.ConvTranspose2d(nb_filter * 2, nb_filter * 1, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(nb_filter * 1)
        self.conv5 = nn.ConvTranspose2d(nb_filter * 1, 3, 4, 2, 1)
        self.__initialize_weights()

    def forward(self, input):
        x = self.fc_layer(input)
        x = x.view(x.size(0), -1, 1, 1)
        print(x.size())  # → torch.Size([100, 512, 1, 1])
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x)
        x = self.conv5(x)
        return torch.tanh(x)

    def __initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)
class Discriminator(nn.Module):
    def __init__(self, nb_filter, num_classes=10):
        super(Discriminator, self).__init__()
        self.nb_filter = nb_filter
        self.conv1 = nn.Conv2d(3, nb_filter, 4, 2, 1)
        self.conv2 = nn.Conv2d(nb_filter, nb_filter * 2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(nb_filter * 2)
        self.conv3 = nn.Conv2d(nb_filter * 2, nb_filter * 4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(nb_filter * 4)
        self.conv4 = nn.Conv2d(nb_filter * 4, nb_filter * 8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(nb_filter * 8)
        self.conv5 = nn.Conv2d(nb_filter * 8, nb_filter * 1, 4, 1, 0)
        self.gan_linear = nn.Linear(nb_filter * 1, 1)
        self.aux_linear = nn.Linear(nb_filter * 1, num_classes)
        self.__initialize_weights()

    def forward(self, input):
        x = self.conv1(input)
        x = F.leaky_relu(x, 0.2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.leaky_relu(x, 0.2)
        x = self.conv5(x)
        x = x.view(-1, self.nb_filter * 1)
        c = self.aux_linear(x)
        s = self.gan_linear(x)
        s = torch.sigmoid(s)
        return s.squeeze(1), c.squeeze(1)

    def __initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = 'cpu'
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
training_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=100, shuffle=True)
D = Discriminator(64, n_classes).to(device)
G = Generator(64, n_classes).to(device)
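# Sanity check I added while debugging: the Discriminator runs fine on a
# dummy batch, so the failure seems isolated to the Generator.
# 64x64 input halves to 32 -> 16 -> 8 -> 4, then conv5 (k=4, s=1) gives 1x1.
dummy = torch.randn(2, 3, 64, 64).to(device)
s, c = D(dummy)
print(s.shape, c.shape)  # torch.Size([2]) torch.Size([2, 10])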
# Adam optimization
optimizerD = torch.optim.Adam(D.parameters(), lr, betas=(beta_1, beta_2))
optimizerG = torch.optim.Adam(G.parameters(), lr, betas=(beta_1, beta_2))
criterion_adv = nn.BCELoss()
criterion_aux = nn.CrossEntropyLoss()
total_step = len(train_loader)
for epoch in range(num_epochs):
    for batch_idx, (x, target) in enumerate(train_loader):
        images = x.to(device)
        current_batchSize = images.size()[0]
        realLabel = torch.ones(current_batchSize).to(device)
        fakeLabel = torch.zeros(current_batchSize).to(device)
        target = target.to(device)
        cls_one_hot = torch.zeros(current_batchSize, n_classes, device=device)
        cls_one_hot[torch.arange(current_batchSize), target] = 1.0

        # TRAIN D
        # On true data
        predictR, predictRLabel = D(images)  # image from the real dataset
        loss_real_adv = criterion_adv(predictR, realLabel)
        loss_real_aux = criterion_aux(predictRLabel, target)
        real_score = predictR

        # On fake data
        latent_value = torch.randn(current_batchSize, latent_size).to(device)
        latent = torch.cat((latent_value, cls_one_hot), dim=1)
        fake_images = G(latent)  # generate a fake image
        predictF, predictFLabel = D(fake_images)
        loss_fake_adv = criterion_adv(predictF, fakeLabel)
        loss_fake_aux = criterion_aux(predictFLabel, gen_labels)
        fake_score = predictF

        lossD = loss_real_adv + loss_real_aux + loss_fake_adv + loss_fake_aux
        optimizerD.zero_grad()
        optimizerG.zero_grad()
        lossD.backward()
        optimizerD.step()

        # TRAIN G
        latent_value = torch.randn(current_batchSize, latent_size).to(device)
        # gen_labels = torch.LongTensor(np.random.randint(0, n_classes, current_batchSize)).to(device)
        # random one hot class
        cls_one_hot = F.one_hot(torch.arange(0, 10), 10)
        cls_one_hot = cls_one_hot[np.random.randint(0, n_classes, current_batchSize)]
        fake_images = G(latent_value, cls_one_hot)  # generate a fake image
        predictG, predictLabel = D(fake_images)
        lossG_adv = criterion_adv(predictG, realLabel)
        lossG_aux = criterion_aux(predictLabel, gen_labels)
        lossG = lossG_adv + lossG_aux
        optimizerD.zero_grad()
        optimizerG.zero_grad()
        lossG.backward()
        optimizerG.step()

        if (batch_idx+1) % 50 == 0:
            print('epoch: ' + str(epoch+1) + '/' + str(num_epochs) + ' batch: ' + str(batch_idx+1) + '/' + str(total_step), ' G loss: ' + str(round(lossG.item(), 3)) + ' D loss: ' + str(round(lossD.item(), 3)))
        if (batch_idx+1) % 100 == 0:
            with torch.no_grad():
                save_image_grid(fake_images, 'samples_training/epoch {} step {}.png'.format(epoch+1, batch_idx+1))
                if (epoch+1) == 1:
                    save_image_grid(images, 'real images.png')
    if (epoch+1) % 50 == 0:
        torch.save(G.state_dict(), 'training_acgan_cifar_epoch {}.pt'.format(epoch+1))
This is the Generator architecture as printed (note the stride=(0, 0) on conv1):
Generator(
  (fc_layer): Linear(in_features=110, out_features=512, bias=True)
  (conv1): ConvTranspose2d(512, 256, kernel_size=(1, 1), stride=(0, 0))
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
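Based on that stride=(0, 0), I tried passing conv1 its arguments by keyword instead. This is only a sketch of what I think the layer should be, not something I’m sure about: kernel_size=4, stride=1, padding=0 would take the 1×1 input to 4×4 before the four doubling layers, and I kept out_channels at nb_filter * 8 because bn1 and conv2 both expect nb_filter * 8 channels:

# (N, 512, 1, 1) -> (N, 512, 4, 4): output size = (1 - 1)*1 - 2*0 + 4 = 4
self.conv1 = nn.ConvTranspose2d(nb_filter * 8, nb_filter * 8, kernel_size=4, stride=1, padding=0)

Is that the right way to set this up?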
Thanks for any help!