Here is the negative of the estimated Wasserstein distance (which is also the adversarial loss of the GAN) for the first few batches. It diverges almost immediately:
```
Epoch:0 batch_num:0 wgan_loss:-16.413176
Epoch:0 batch_num:1 wgan_loss:14472.721
Epoch:0 batch_num:2 wgan_loss:-10957247.0
Epoch:0 batch_num:3 wgan_loss:-455000130.0
Epoch:0 batch_num:4 wgan_loss:-3773285000.0
```
I've tried lowering the learning rate, with the batch size limited to 20. The following code is what I've used for the WGAN discriminator, which is integrated with a Gradient Reversal Layer (GRL). The calc_coeff function determines the scaling factor applied to the reversed gradient during back-propagation.
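```python
import numpy as np
import torch
import torch.nn as nn


def grl_hook(coeff):
    # Backward hook: negate the incoming gradient and scale it by coeff
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1


def calc_coeff(iter_num, high=1.0, low=0.0, alpha=2.0, max_iter=50.0):
    # Equals `low` at iter_num=0 and rises smoothly toward `high` as iter_num grows.
    # np.float was removed in NumPy 1.24, so the builtin float is used instead.
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)


class DiscriminatorforWGAN(nn.Module):
    def __init__(self, in_feature, hidden_size):
        # super() must name this class (the original named AdversarialNetworkforCDAN)
        super(DiscriminatorforWGAN, self).__init__()
        self.ad_layer1 = nn.Linear(in_feature, hidden_size)
        self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
        self.ad_layer3 = nn.Linear(hidden_size, 1)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)
        # GRL schedule settings passed to calc_coeff
        self.iter_num = -1
        self.alpha = 1.0
        self.low = 0.0
        self.high = 1.0
        self.max_iter = 15.0
        self.coeff = 0.02

    def forward(self, x):
        if self.training:
            self.iter_num += 1
        # Cap the counter so the coefficient stops growing after max_iter
        if self.iter_num >= self.max_iter:
            self.iter_num = self.max_iter
        coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
        self.coeff = coeff
        # Multiply by 1.0 to get a new tensor to attach the GRL hook to.
        # register_hook raises if x is not part of the autograd graph
        # (e.g. under torch.no_grad()), so only attach it when gradients flow.
        x = x * 1.0
        if x.requires_grad:
            x.register_hook(grl_hook(coeff))
        x = self.ad_layer1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.ad_layer2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        y = self.ad_layer3(x)  # raw, unbounded critic score
        return y
```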
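To see what the GRL does in isolation, here is a minimal sketch that reuses the posted grl_hook and calc_coeff (with float in place of np.float) on a toy tensor. It checks that the gradient flowing back through the hooked tensor is negated and scaled, and evaluates the coefficient schedule with the settings from the class above (alpha=1.0, max_iter=15.0). The tensor `features` is a hypothetical stand-in for the generator's output.

```python
import numpy as np
import torch


def grl_hook(coeff):
    # Same as above: negate and scale the gradient on the way back
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1


def calc_coeff(iter_num, high=1.0, low=0.0, alpha=2.0, max_iter=50.0):
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)


# Hypothetical leaf tensor standing in for the generator's features
features = torch.ones(3, requires_grad=True)
hooked = features * 1.0
hooked.register_hook(grl_hook(0.5))
hooked.sum().backward()
# Without the hook the gradient would be +1 everywhere; the hook turns it into -0.5
print(features.grad)  # tensor([-0.5000, -0.5000, -0.5000])

# Coefficient schedule with the discriminator's settings (alpha=1.0, max_iter=15.0)
schedule = [calc_coeff(i, high=1.0, low=0.0, alpha=1.0, max_iter=15.0) for i in range(16)]
print(round(schedule[0], 4), round(schedule[-1], 4))  # 0.0 0.4621
```

With these settings the coefficient never exceeds about 0.462: iter_num is capped at max_iter=15, and at that point the curve equals tanh(alpha/2) = tanh(0.5), well short of high=1.0.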
The generator is a simple CNN-based network. I have strictly followed the other WGAN requirements as well, such as clipping the discriminator's weights and using RMSprop as the optimizer.
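For reference, weight clipping and RMSprop typically look like the minimal sketch below for a single critic update. The toy critic, the random features, and the values lr=5e-5 and clip_value=0.01 (the defaults from the original WGAN paper) are assumptions for illustration; the point is that the clamp runs after every optimizer step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy critic; in the post this role is played by DiscriminatorforWGAN
critic = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
clip_value = 0.01  # assumed clipping range [-c, c]

# Hypothetical feature batches (batch size 20, as in the post)
source_feats = torch.randn(20, 8)
target_feats = torch.randn(20, 8)

# Critic loss: same formula as the wgan_loss function in this post
loss = -critic(target_feats).mean() + critic(source_feats).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Weight clipping: in place, outside autograd, after every step
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-clip_value, clip_value)
```

If the critic shares an optimizer with the generator, as a GRL setup often does, the clamp should still follow each step and be applied only to the critic's parameters.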
My WGAN loss is as follows:
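```python
import torch


def wgan_loss(values_from_target_side, values_from_source_side):
    # Critic loss -E[D(target)] + E[D(source)]; its negation estimates the Wasserstein distance
    W_loss = -torch.mean(values_from_target_side) + torch.mean(values_from_source_side)
    return W_loss
```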
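A tiny worked example of the sign convention, using made-up critic scores (not values from the run above): with target scores averaging 2.0 and source scores averaging 1.0, the loss is -2.0 + 1.0 = -1.0, and its gradient is -1/n for each target score and +1/n for each source score, so gradient descent pushes target scores up and source scores down. Since the critic ends in a plain nn.Linear, the loss has no lower bound of its own; its scale is held in check only by the Lipschitz constraint (weight clipping here) and by the magnitude of the input features.

```python
import torch


def wgan_loss(values_from_target_side, values_from_source_side):
    W_loss = -torch.mean(values_from_target_side) + torch.mean(values_from_source_side)
    return W_loss


# Made-up critic scores for two target and two source samples
target_scores = torch.tensor([1.0, 3.0], requires_grad=True)
source_scores = torch.tensor([0.5, 1.5], requires_grad=True)

loss = wgan_loss(target_scores, source_scores)
loss.backward()
print(loss.item())         # -1.0
print(target_scores.grad)  # tensor([-0.5000, -0.5000])
print(source_scores.grad)  # tensor([0.5000, 0.5000])
```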