I am working on a WGAN-GP with a dataset of 292 training images, and I am facing a few problems:
1- The generator's loss graph looks completely flat even after 25 epochs.
2- After just a few epochs, this error occurs: “DataLoader worker (pid(s) 6740) exited unexpectedly”.
Error:
E:\Users\Asus\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
1001 if len(failed_workers) > 0:
1002 pids_str = ', '.join(str(w.pid) for w in failed_workers)
-> 1003 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
1004 if isinstance(e, queue.Empty):
1005 return (False, None)
RuntimeError: DataLoader worker (pid(s) 6740) exited unexpectedly
3- The generator keeps producing flat noise images. An image is attached for further clarity.
Training Loop:
# WGAN-GP training loop: for every batch the critic is updated `crit_repeats`
# times, then the generator is updated once.
# NOTE(review): the pasted source had lost its indentation; the nesting below
# (per-batch generator update and visualization, per-epoch image saving) is
# reconstructed from the statement order -- confirm against the real script.

# A resumed run (epoch_flag truthy) keeps the externally supplied
# `previous_epochs`; a fresh run counts epochs from zero.
if not epoch_flag:
    previous_epochs = 0

cur_step = 0
generator_losses = []
critic_losses = []

for epoch in range(n_epochs):
    # The data loader yields (images, labels) pairs; the labels are unused.
    for real_images, _ in tqdm(data_loader):
        cur_batch_size = len(real_images)
        real_images = real_images.to(device)

        ### Update critic ###
        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # The critic step never back-propagates into the generator, so
            # do not build the generator's autograd graph at all.
            with torch.no_grad():
                fake = gen(fake_noise)
            crit_fake_pred = crit(fake)
            crit_real_pred = crit(real_images)

            epsilon = torch.rand(
                len(real_images), 1, 1, 1, device=device, requires_grad=True
            )
            gradient = get_gradient(crit, real_images, fake, epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch.
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats

            # Every tensor feeding crit_loss is recomputed in this iteration,
            # so the graph is only walked once and need not be retained.
            crit_loss.backward()
            crit_opt.step()
        critic_losses.append(mean_iteration_critic_loss)

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)
        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()
        gen_opt.step()
        generator_losses.append(gen_loss.item())

        ### Visualization code ###
        if cur_step % display_step == 0:
            # Average over the losses actually recorded so that early steps
            # (fewer than display_step entries) are not under-reported.
            recent_gen = generator_losses[-display_step:]
            recent_crit = critic_losses[-display_step:]
            gen_mean = sum(recent_gen) / len(recent_gen)
            crit_mean = sum(recent_crit) / len(recent_crit)
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real_images)

            # Plot both loss histories averaged over bins of `step_bins` steps;
            # trailing steps that do not fill a whole bin are left out.
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss",
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss",
            )
            plt.legend()
            plt.show()
        cur_step += 1

    # previous_epochs is 0 on a fresh run, so this equals epoch + 1 there and
    # continues the numbering when resuming.
    all_epochs = epoch + 1 + previous_epochs
    save_fake_images(all_epochs)
Generator’s Code:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a noise vector to an image.

    The noise is reshaped to ``(batch, z_dim, 1, 1)`` and passed through nine
    transposed-convolution stages. With the kernel/stride plan below, the
    spatial size grows 1 -> 3 -> 6 -> 13 -> 16 -> 33 -> 36 -> 73 -> 76 -> 153,
    so the output is ``(batch, im_chan, 153, 153)`` with values in [-1, 1]
    (the last stage ends in ``Tanh``).

    Args:
        z_dim: length of the input noise vector.
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width; hidden stages use multiples of it.
    """

    def __init__(self, z_dim=32, im_chan=1, hidden_dim=32):
        super().__init__()
        self.z_dim = z_dim

        # (input_channels, output_channels, kernel_size, stride) per hidden
        # stage; the trailing comment is the resulting spatial size.
        hidden_plan = [
            (z_dim, hidden_dim * 2, 3, 2),             # 3 x 3
            (hidden_dim * 2, hidden_dim * 4, 4, 1),    # 6 x 6
            (hidden_dim * 4, hidden_dim * 8, 3, 2),    # 13 x 13
            (hidden_dim * 8, hidden_dim * 16, 4, 1),   # 16 x 16
            (hidden_dim * 16, hidden_dim * 16, 3, 2),  # 33 x 33
            (hidden_dim * 16, hidden_dim * 8, 4, 1),   # 36 x 36
            (hidden_dim * 8, hidden_dim * 4, 3, 2),    # 73 x 73
            (hidden_dim * 4, hidden_dim * 2, 4, 1),    # 76 x 76
        ]
        stages = [
            self.make_gen_block(c_in, c_out, kernel_size=kernel, stride=stride)
            for c_in, c_out, kernel, stride in hidden_plan
        ]
        # Output stage: 153 x 153, squashed to [-1, 1] by Tanh.
        stages.append(self.make_gen_block(hidden_dim * 2, im_chan, final_layer=True))

        # Keep the attribute name and stage order so state_dict keys
        # ("gen.<stage>.<layer>...") stay stable for saved checkpoints.
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one generator stage.

        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels produced by the transposed convolution.
            kernel_size: kernel size of the transposed convolution.
            stride: stride of the transposed convolution.
            final_layer: when True the stage is ``ConvTranspose2d -> Tanh``;
                otherwise ``ConvTranspose2d -> BatchNorm2d -> ReLU``.

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate images from ``noise`` (reshaped to ``(len(noise), z_dim, 1, 1)``)."""
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
Discriminator’s Code:
class Critic(nn.Module):
    """Convolutional WGAN-GP critic scoring each image with one real number.

    Nine strided/unstrided convolution stages reduce a 153 x 153 input to a
    single 1 x 1 score map (153 -> 76 -> 73 -> 36 -> 33 -> 16 -> 13 -> 6 -> 3
    -> 1), which ``forward`` flattens to ``(batch, 1)``.

    Hidden stages use per-sample layer normalization (``GroupNorm`` with a
    single group) instead of batch normalization: the WGAN-GP gradient penalty
    is defined per input sample, and batch normalization makes every output
    depend on the whole batch, which invalidates that penalty.

    NOTE: because the normalization layers no longer carry BatchNorm running
    statistics, checkpoints saved from a BatchNorm-based critic contain extra
    buffers and must be loaded with ``strict=False`` (or retrained).

    Args:
        im_chan: number of channels of the input image.
        hidden_dim: base channel width; hidden stages use multiples of it.
    """

    def __init__(self, im_chan=1, hidden_dim=64):
        super().__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_crit_block(hidden_dim * 2, hidden_dim * 4),
            self.make_crit_block(hidden_dim * 4, hidden_dim * 8, kernel_size=4, stride=1),
            self.make_crit_block(hidden_dim * 8, hidden_dim * 8),
            self.make_crit_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=1),
            self.make_crit_block(hidden_dim * 4, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=1),
            self.make_crit_block(hidden_dim, 1, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one critic stage.

        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels produced by the convolution.
            kernel_size: kernel size of the convolution.
            stride: stride of the convolution.
            final_layer: when True the stage is a bare ``Conv2d`` (unbounded
                score); otherwise ``Conv2d -> GroupNorm(1, C) -> LeakyReLU(0.2)``.

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(
            conv,
            # One group over all channels == layer norm over (C, H, W),
            # computed independently for every sample in the batch.
            nn.GroupNorm(1, output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Return the critic scores for ``image`` flattened to ``(batch, -1)``."""
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)
Please help me figure out these issues.