I wanted to make a WGAN-GP that generates samples from a mixture-of-Gaussians model, so I sampled from the mixture like this:
import torch
import torch.distributions as D
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
import torch.nn as nn
def samplling_from_MOG(n_sample, pi, mu, std):
    """Draw n_sample points from a 1-D mixture of Gaussians.

    Args:
        n_sample: total number of points to draw.
        pi: mixture weights, one per component (should sum to 1).
        mu: per-component means.
        std: per-component standard deviations.

    Returns:
        1-D numpy array of length n_sample. Samples are grouped by
        component (not shuffled), matching the original implementation.
    """
    # One multinomial draw replaces the original manual bucketing of
    # Uniform(0, 1) samples: it assigns a count to every component with
    # exactly the same distribution and guarantees the counts sum to
    # n_sample (the original reached the same guarantee by giving the
    # last component the remainder).
    counts = np.random.multinomial(n_sample, np.asarray(pi, dtype=np.float64))
    samples = []
    for count, mean, sigma in zip(counts, mu, std):
        component = D.Normal(torch.tensor([mean]), torch.tensor([sigma]))
        # .ravel() flattens the (count, 1) sample block to 1-D.
        samples.append(component.sample(sample_shape=[int(count)]).numpy().ravel())
    # Guard the degenerate empty-mixture case.
    return np.concatenate(samples) if samples else np.array([])
# Draw 10k points from a 3-component mixture (weights 0.1/0.5/0.4,
# means 1/2/5, stds 0.2/0.2/1) and eyeball the density with a histogram.
sampled_data = samplling_from_MOG(10000,[0.1,0.5,0.4],[1.0,2.0,5.0],[0.2,0.2,1])
plt.hist(sampled_data,bins=100,density=True)
plt.show()
Here is the sampling result (image omitted), and below is my WGAN-GP model:
class BasicDataset(data.Dataset):
    """Minimal map-style dataset that wraps a single indexable array."""

    def __init__(self, x_tensor):
        super().__init__()
        self.x = x_tensor  # any object supporting __getitem__ / __len__

    def __getitem__(self, index):
        """Return the sample at `index` unchanged."""
        return self.x[index]

    def __len__(self):
        """Number of samples in the wrapped array."""
        return len(self.x)
# Wrap the numpy sample array in a Dataset / DataLoader pair (batch 32,
# reshuffled every epoch).
# NOTE(review): sampled_data is float64, so batches come out of the loader
# as double tensors — the training loop must cast them with .float().
tr_dataset = BasicDataset(sampled_data)
data_loader = data.DataLoader(dataset=tr_dataset, batch_size=32, shuffle=True)
class Generator(nn.Module):
    """MLP generator: latent noise (B, latent_dim) -> 1-D samples (B, 1).

    Two Linear -> ReLU -> BatchNorm stages followed by a linear output
    head.
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 400)
        self.bn1 = torch.nn.BatchNorm1d(400)
        self.fc2 = nn.Linear(400, 300)
        self.bn2 = torch.nn.BatchNorm1d(300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, z):
        # Identical computation to the original: relu then batch-norm,
        # stage by stage, expressed as fused expressions.
        hidden = self.bn1(torch.relu(self.fc1(z)))
        hidden = self.bn2(torch.relu(self.fc2(hidden)))
        return self.fc3(hidden)
class Discriminator(nn.Module):
    """Critic: maps a (B, 1) sample to an unbounded scalar score (B, 1).

    Plain three-layer ReLU MLP with no normalization layers.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(1, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, z):
        hidden = torch.relu(self.fc1(z))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculate the WGAN-GP gradient penalty ((||grad D(x_hat)|| - 1)^2).

    Args:
        D: critic callable mapping a (B, 1) batch to (B, 1) scores.
           (Shadows the module-level `torch.distributions as D` alias
           inside this function — kept for interface compatibility.)
        real_samples: batch of real data, shape (B, 1).
        fake_samples: generator output with the same shape.

    Returns:
        Scalar tensor: mean squared deviation of the per-sample gradient
        norm from 1.
    """
    # Per-sample interpolation weight in [0, 1). torch.rand stays float32
    # on the right device, unlike the original np.random.random alpha,
    # which was float64 and silently promoted the whole interpolation
    # before a post-requires_grad_() .float() cast.
    alpha = torch.rand(real_samples.size(0), 1, device=real_samples.device)
    # Cast first, then mark as requiring grad, so the grad-enabled tensor
    # is the one actually fed to the critic.
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples)
    interpolates = interpolates.float().requires_grad_(True)
    d_interpolates = D(interpolates)
    # Seed gradient of ones: differentiate the sum of critic scores.
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,   # keep graph so the penalty itself is differentiable
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
Training code:
# Loss weight for gradient penalty
# NOTE(review): the WGAN-GP paper uses lambda_gp = 10; 1 may be too weak
# to keep the critic near 1-Lipschitz — worth trying 10.
lambda_gp = 1
latent_dim = 100
lr = 0.0002
# Initialize generator and discriminator
generator = Generator(latent_dim)
discriminator = Discriminator()
# NOTE(review): the WGAN-GP paper recommends Adam(betas=(0.5, 0.9));
# the default (0.9, 0.999) used here is known to destabilize GAN training.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
n_epochs = 100
# Alternate critic/generator updates: critic every batch, generator once
# every 5 batches (n_critic = 5, as in the WGAN-GP recipe).
for epoch in range(n_epochs):
    for i, real_data in enumerate(data_loader):
        # Configure input. torch.as_tensor avoids the copy warning that
        # torch.tensor() emits when handed the tensor the DataLoader
        # already produced; .float() downcasts the float64 batch.
        real_data = torch.as_tensor(real_data).unsqueeze(dim=1).float()

        # ---------------------
        #  Train Discriminator (critic)
        # ---------------------
        optimizer_D.zero_grad()
        # Sample noise as generator input (standard normal, float32).
        z = torch.randn(real_data.shape[0], latent_dim)
        # Generate a batch of samples.
        fake_data = generator(z)
        # Real samples.
        real_validity = discriminator(real_data)
        # Fake samples — detached so d_loss.backward() does not backprop
        # into the generator (the original relied on a later
        # optimizer_G.zero_grad() to discard those stray gradients).
        fake_validity = discriminator(fake_data.detach())
        # Gradient penalty (also on detached fakes: gradients must flow
        # through the interpolates only, not back into the generator).
        gradient_penalty = compute_gradient_penalty(
            discriminator, real_data, fake_data.detach()
        )
        # Wasserstein critic loss + penalty.
        d_loss = (
            -torch.mean(real_validity)
            + torch.mean(fake_validity)
            + lambda_gp * gradient_penalty
        )
        d_loss.backward()
        optimizer_D.step()

        # Train the generator every n_critic steps.
        if i % 5 == 0:
            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Fresh forward pass so the graph reflects the just-updated
            # critic.
            fake_data = generator(z)
            # Loss measures the generator's ability to fool the critic.
            fake_validity = discriminator(fake_data)
            g_loss = -torch.mean(fake_validity)
            g_loss.backward()
            optimizer_G.step()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(data_loader), d_loss.item(), g_loss.item())
            )
Result plot code:
# Sample noise as generator input.
z = torch.randn(10000, latent_dim)
# Switch BatchNorm layers to inference mode before sampling: in train
# mode they normalize with this batch's statistics (and update the
# running estimates) instead of using the statistics learned in training.
generator.eval()
with torch.no_grad():
    # Generate a batch of samples without building an autograd graph.
    fake_data = generator(z)
plt.hist(fake_data.numpy(), bins=100, density=True)
plt.show()
This is my generator's output (image omitted), but the result is bad. I don't know what is wrong — please help me.