Guys, I’m getting this error while trying to run a Vanilla Gan on the SVHN dataset using PyTorch, could someone help me?
#dataset
train_dataset = torchvision.datasets.SVHN(‘./data’,
transform=torchvision.transforms.ToTensor(),
download=True)
#define generator
class Generator(nn.Module):
def init(self, latent_dims):
super(Generator, self).init()
self.latent_dims = latent_dims
self.net = nn.Sequential(
nn.Linear(self.latent_dims, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 32*32),
nn.Tanh(),
)
def forward(self, x):
return self.net(x).view(-1, 1, 32, 32)
#define discriminator
class Discriminator(nn.Module):
def init(self):
super(Discriminator, self).init()
self.n_input = 32*32
self.net = nn.Sequential(
nn.Linear(self.n_input, 2048),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(2048, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, x):
x = torch.flatten(x, start_dim=1)
return self.net(x)
#define GAN
class VanillaGAN:
def init(self, latent_dims, criterion):
self.latent_dims = latent_dims
self.generator = Generator(latent_dims)
self.discriminator = Discriminator()
self.criterion = criterion
self.optimizer_g = optim.Adam(self.generator.parameters(), lr=3e-4)
self.optimizer_d = optim.Adam(self.discriminator.parameters(), lr=3e-4)
def to(self, device):
self.generator.to(device)
self.discriminator.to(device)
def train_discriminator(self, real_data, fake_data):
batch_size = real_data.size(0)
real_label = torch.ones(batch_size, 1).to(device) # real labels (1s)
fake_label = torch.zeros(batch_size, 1).to(device) # fake labels (0s)
self.optimizer_d.zero_grad()
output_real = self.discriminator(real_data)
loss_real = criterion(output_real, real_label)
output_fake = self.discriminator(fake_data)
loss_fake = criterion(output_fake, fake_label)
loss_real.backward()
loss_fake.backward()
self.optimizer_d.step()
return (loss_real + loss_fake) * 0.5
def train_generator(self, fake_data):
batch_size = fake_data.size(0)
real_label = torch.ones(batch_size, 1).to(device) # real labels (1s)
self.optimizer_g.zero_grad()
output = self.discriminator(fake_data)
loss = self.criterion(output, real_label)
loss.backward()
self.optimizer_g.step()
return loss
#function to create the noise vector
def create_noise(sample_size, latent_dims):
return torch.randn(sample_size, latent_dims).to(device)
def train(model, train_loader, epochs=20):
log_dict = {“train_generator_loss”: [],
“train_discriminator_loss”: [],
“reconstructed_images”:[]}
for epoch in tqdm.tqdm_notebook(range(epochs)):
loss_g = 0.0
loss_d = 0.0
model.discriminator.train()
model.generator.train()
for data in tqdm.tqdm_notebook(train_loader):
image, _ = data
image = image.to(device)
batch_size = len(image)
# train the discriminator network
random_noise = create_noise(batch_size, model.latent_dims)
fake_data = model.generator(random_noise).detach() # detach() prevents from training generator
real_data = image
loss_d += model.train_discriminator(real_data, fake_data)
# train the generator network
random_noise = create_noise(batch_size, model.latent_dims)
fake_data = model.generator(random_noise)
loss_g += model.train_generator(fake_data)
log_dict["train_generator_loss"].append(loss_g.item()/len(train_loader))
log_dict["train_discriminator_loss"].append(loss_d.item()/len(train_loader))
log_dict["reconstructed_images"].append(fake_data.to('cpu').detach()[:5])
log_dict["original_images"] = image.to('cpu').detach()[:5]
return model, log_dict
criterion = nn.BCELoss()
latent_dims = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
model = VanillaGAN(latent_dims, criterion)
model.to(device)
model, log_dict = train(model, train_loader, epochs=300)
Im getting this error:
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:6: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0
Please use tqdm.notebook.tqdm
instead of tqdm.tqdm_notebook
0%
0/300 [00:00<?, ?it/s]
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0
Please use tqdm.notebook.tqdm
instead of tqdm.tqdm_notebook
This is added back by InteractiveShellApp.init_path()
0%
0/573 [00:00<?, ?it/s]
RuntimeError Traceback (most recent call last)
in
6 model = VanillaGAN(latent_dims, criterion)
7 model.to(device)
----> 8 model, log_dict = train(model, train_loader, epochs=300)
7 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
112
113 def forward(self, input: Tensor) → Tensor:
→ 114 return F.linear(input, self.weight, self.bias)
115
116 def extra_repr(self) → str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x3072 and 1024x2048)