I’m trying to implement a PyTorch version of Creative Adversarial Networks (CAN), a GAN with a modified/custom loss function.
Here are the formulae for the loss function.

I’m currently using PyTorch’s `nn.CrossEntropyLoss` for the discriminator’s modified loss function, but I don’t think `nn.CrossEntropyLoss` is suitable for the generator. It seems to expect `Long` (class-index) targets rather than `Float` tensors, and the paper’s loss function, particularly the generator’s loss, seems to me like it would require floats, as one multiplies the loss by 1/K, where K is the number of classes. I’m not 100% sure that using `nn.CrossEntropyLoss` is correct for the discriminator either, but it seems to work fine thus far.
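For reference, `nn.CrossEntropyLoss` applies `log_softmax` to raw logits internally and, in its classic form, takes `Long` class indices as targets. That matches the discriminator’s class term log D_c(c = ĉ | x) for real images with known style labels. Since PyTorch 1.10 it also accepts float class-probability targets of the same shape as the input. The sketch below checks both forms against a manual computation; the batch size, K = 5 and the random logits are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, K = 4, 5                              # hypothetical batch size and number of styles
logits = torch.randn(N, K)               # raw (pre-softmax) discriminator class outputs
labels = torch.randint(0, K, (N,))       # true style indices for real images (dtype torch.long)

ce = nn.CrossEntropyLoss()               # default reduction="mean" over the batch

# 1) Integer targets: batch mean of -log softmax(logits)[i, labels[i]]
loss_idx = ce(logits, labels)
manual_idx = -F.log_softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).mean()

# 2) Float probability targets (PyTorch >= 1.10): a uniform 1/K target per class
uniform = torch.full((N, K), 1.0 / K)
loss_prob = ce(logits, uniform)
manual_prob = -(uniform * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(loss_idx.item(), manual_idx.item())
print(loss_prob.item(), manual_prob.item())
```

Note that the uniform-target cross-entropy only reproduces the (1/K) log D_c terms of the generator’s style term; the (1 − 1/K) log(1 − D_c) terms have no counterpart in `nn.CrossEntropyLoss`.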
This is my current (initial) thinking for the generator’s custom loss. The `for` loop attempts to be the equivalent of:

$$\sum_{k=1}^{K} \left[ \frac{1}{K} \log D_c\big(c_k \mid G(z)\big) + \left(1 - \frac{1}{K}\right) \log\Big(1 - D_c\big(c_k \mid G(z)\big)\Big) \right],$$

where $D_c(c_k \mid x)$ is the discriminator’s estimated probability that an input image $x$ belongs to class $c_k$.
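```python
import torch
import torch.nn as nn

# y_dim = number of classes (K)
# disc_class_layer = fully connected layer that outputs one logit per style/class given an input image
class CanGLoss(nn.Module):
    def __init__(self, y_dim, disc_class_layer):
        super(CanGLoss, self).__init__()
        self.y_dim = y_dim  # store the arguments so forward() can use them
        self.disc_class_layer = disc_class_layer

    def forward(self, inp):
        K = self.y_dim
        probs = torch.softmax(self.disc_class_layer(inp), dim=1)  # D_c(c_k|G(z)) for all k, shape (N, K)
        style_loss = 0
        for k in range(K):  # constant weight 1/K for every class k (not 1/k)
            p_k = probs[:, k]
            style_loss += (1 / K) * torch.log(p_k) + (1 - 1 / K) * torch.log(1 - p_k)
        return -style_loss.mean()  # multiply by -1 as it is subtracted from the normal GAN loss; average over batch
```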
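As a sanity check on the formula itself: each summand (1/K) log p_k + (1 − 1/K) log(1 − p_k), with p_k = D_c(c_k | G(z)), is the negative binary cross-entropy between p_k and a constant target of 1/K. So the negated sum equals `F.binary_cross_entropy` with a 1/K target, summed over the K classes, which also explains why float targets are natural here. The sketch below assumes the class probabilities come from a softmax over hypothetical random logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, K = 4, 5                                   # hypothetical batch size and number of styles
probs = F.softmax(torch.randn(N, K), dim=1)   # stand-in for D_c(c_k | G(z)), shape (N, K)

# Direct transcription of the sum over k = 1..K, vectorized over the class dimension
direct = ((1.0 / K) * torch.log(probs)
          + (1.0 - 1.0 / K) * torch.log(1.0 - probs)).sum(dim=1)   # shape (N,)

# Same quantity via binary cross-entropy with a constant 1/K target, negated
target = torch.full_like(probs, 1.0 / K)
via_bce = -F.binary_cross_entropy(probs, target, reduction="none").sum(dim=1)

print(direct)
print(via_bce)
```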
I am new to custom loss functions and PyTorch and not sure this is the way to go. Is this on the right track? Using a `for` loop seems awkward, but I’m not sure how else to capture the sum from k = 1 to K.
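The sum over k does not strictly need a Python loop: if the classifier head returns one logit per class for the whole batch, all K terms can be computed at once on the `(N, K)` tensor and summed along the class dimension. The sketch below is one possible vectorized version, not the paper’s reference code. It assumes the classifier outputs raw logits (no softmax), uses `log_softmax` for log p_k, clamps p_k slightly below 1 (the hypothetical `eps`) so log(1 − p_k) stays finite, and averages over the batch. The class name `CanStyleLoss` and the stand-in `nn.Linear` head are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CanStyleLoss(nn.Module):
    """Hypothetical vectorized generator style loss:
    -(1/N) * sum_n sum_k [ (1/K) log p_nk + (1 - 1/K) log(1 - p_nk) ],
    where p_nk = softmax(logits)[n, k] = D_c(c_k | G(z_n))."""

    def __init__(self, num_classes, eps=1e-6):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps  # keeps 1 - p_k away from 0 so the log stays finite

    def forward(self, class_logits):
        K = self.num_classes
        log_p = F.log_softmax(class_logits, dim=1)           # log p_k, shape (N, K)
        p = log_p.exp().clamp(max=1.0 - self.eps)
        log_1mp = torch.log1p(-p)                             # log(1 - p_k)
        per_sample = ((1.0 / K) * log_p + (1.0 - 1.0 / K) * log_1mp).sum(dim=1)  # sum over k
        return -per_sample.mean()  # negate: this term is subtracted from the usual GAN loss


# Usage sketch with a stand-in classifier head (hypothetical sizes)
torch.manual_seed(0)
K, feat_dim = 5, 16
disc_class_layer = nn.Linear(feat_dim, K)   # outputs raw logits, one per style
fake_features = torch.randn(8, feat_dim, requires_grad=True)
loss = CanStyleLoss(K)(disc_class_layer(fake_features))
loss.backward()
print(loss.item())
```

Passing the logits into the loss, rather than the layer itself, keeps the discriminator’s parameters from being registered as a submodule of the loss. The loss is smallest when the predicted style distribution is uniform (p_k = 1/K for every k), which is the "style ambiguity" the generator is pushed towards.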
Any help would be great!