Hello,
I recently learned PyTorch and was trying to recreate a DCGAN that I had written in TensorFlow, but I am having trouble with its performance. The TensorFlow model generates reasonable-looking MNIST digits and CelebA faces. With the same model in PyTorch, however, the discriminator loss converges to 0 and the generator loss increases steadily.
The PyTorch model might not be exactly the same, since some PyTorch functions differ slightly from their TensorFlow counterparts. For instance, for convolution:
TensorFlow:
tf.layers.conv2d(relu1, 256, 5, strides=2, padding='same')
tf.layers.conv2d_transpose(x1, 128, 5, strides=2, padding='same')
PyTorch:
nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2)
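A quick way to check whether these mappings really match is to compute the output sizes with each framework's formula. The sketch below is plain Python arithmetic with no framework calls; the helper names and the example sizes (a 28×28 input to the strided conv, a 7×7 feature map to the transposed conv) are hypothetical, picked to resemble an MNIST DCGAN. Under those assumptions, padding=2 reproduces TensorFlow's 'same' output size for the strided Conv2d, although TensorFlow pads asymmetrically (1 pixel before, 2 after) while PyTorch pads 2 on both sides. For ConvTranspose2d, padding=2 alone gives 13 rather than 14, and output_padding=1 would be needed to exactly double the size.

```python
import math

def tf_same_conv(in_size, k, s):
    """Output size and (before, after) padding for TensorFlow padding='same'."""
    out = math.ceil(in_size / s)
    pad_total = max((out - 1) * s + k - in_size, 0)
    return out, (pad_total // 2, pad_total - pad_total // 2)

def torch_conv(in_size, k, s, p):
    """Output size of nn.Conv2d along one spatial dim (dilation=1)."""
    return (in_size + 2 * p - k) // s + 1

def tf_same_conv_transpose(in_size, s):
    """Output size of conv2d_transpose with padding='same'."""
    return in_size * s

def torch_conv_transpose(in_size, k, s, p, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dim (dilation=1)."""
    return (in_size - 1) * s - 2 * p + k + output_padding

print(tf_same_conv(28, 5, 2))                              # (14, (1, 2))
print(torch_conv(28, 5, 2, 2))                             # 14
print(tf_same_conv_transpose(7, 2))                        # 14
print(torch_conv_transpose(7, 5, 2, 2))                    # 13
print(torch_conv_transpose(7, 5, 2, 2, output_padding=1))  # 14
```

These formulas assume dilation=1 and square feature maps.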
Similarly, for batch normalization:
TensorFlow:
tf.layers.batch_normalization(x1, training=is_train)
PyTorch:
nn.BatchNorm2d(256, momentum=0.99, eps=0.001)
As you can see, I copied the default TensorFlow parameters over to PyTorch. I am also aware that the channel axis differs between PyTorch (NCHW) and TensorFlow (NHWC by default), and I even permuted the inputs and outputs to match TensorFlow's behavior. However, nothing seems to help.
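One detail that is easy to miss when copying parameters: the two frameworks define batch-norm momentum in opposite directions. tf.layers.batch_normalization updates its moving averages as momentum * moving + (1 - momentum) * batch, while nn.BatchNorm2d uses (1 - momentum) * running + momentum * batch, so TensorFlow's default momentum=0.99 corresponds to momentum=0.01 in PyTorch. The sketch below (plain Python, hypothetical helper names) just applies both update rules to a scalar statistic. This only affects the running statistics used in eval mode; in training mode both layers normalize with the current batch statistics.

```python
def tf_moving_update(moving, batch_stat, momentum=0.99):
    # tf.layers.batch_normalization: momentum weights the OLD moving average
    return momentum * moving + (1 - momentum) * batch_stat

def torch_running_update(running, batch_stat, momentum=0.1):
    # nn.BatchNorm2d: momentum weights the NEW batch statistic
    return (1 - momentum) * running + momentum * batch_stat

# Start from a running value of 0.0 and observe one batch with statistic 1.0
print(tf_moving_update(0.0, 1.0, momentum=0.99))      # ≈ 0.01 (slow update)
print(torch_running_update(0.0, 1.0, momentum=0.99))  # ≈ 0.99 (almost replaced)
print(torch_running_update(0.0, 1.0, momentum=0.01))  # ≈ 0.01 (matches TF)
```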
I also suspected that the model initialization might be the difference. Below are images generated by an untrained network (each is a grid of samples):
[Image: PyTorch (manual Xavier uniform initialization)]
It seems that PyTorch's default initialization is very different from TensorFlow's. What initialization does PyTorch use by default?
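For reference, the two default weight ranges can be compared numerically. In recent PyTorch versions, nn.Conv2d initializes its weights with kaiming_uniform_(a=sqrt(5)), which reduces to U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = in_channels * k * k, and draws its bias from the same range; tf.layers.conv2d defaults to Glorot (Xavier) uniform weights and zero bias. The plain-Python sketch below uses hypothetical helper names and the Conv2d(128, 256, 5) layer from the convolution example to compute both bounds.

```python
import math

def torch_default_conv_bound(in_channels, k):
    """Half-width of nn.Conv2d's default weight init: kaiming_uniform_(a=sqrt(5))."""
    fan_in = in_channels * k * k
    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1 + a * a))  # leaky_relu gain for negative slope a
    std = gain / math.sqrt(fan_in)
    return math.sqrt(3.0) * std  # uniform bound = sqrt(3) * std, i.e. 1/sqrt(fan_in)

def glorot_uniform_bound(in_channels, out_channels, k):
    """Half-width of Glorot/Xavier uniform, TensorFlow's default kernel initializer."""
    fan_in = in_channels * k * k
    fan_out = out_channels * k * k
    return math.sqrt(6.0 / (fan_in + fan_out))

# Conv2d(128, 256, kernel_size=5)
print(torch_default_conv_bound(128, 5))   # ≈ 0.0177
print(glorot_uniform_bound(128, 256, 5))  # ≈ 0.025
```

For this layer the PyTorch default range is narrower than the Glorot range. In PyTorch, Glorot weights and zero biases can be applied with torch.nn.init.xavier_uniform_ and torch.nn.init.zeros_.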
Even with Xavier initialization, though, I still get the same behavior. Increasing the generator's learning rate does keep the discriminator loss from decaying to 0 and the generator loss from increasing, but then the model fails to learn anything useful.
[Image: PyTorch results]
[Image: TensorFlow results (run for fewer iterations)]
Does anyone know what could be causing this?
Here are the discriminator and generator losses:
PyTorch (broken)
https://pastebin.com/Mqjjrzmd
TensorFlow (correct)
https://pastebin.com/xGuALVT2
Here is the code for the model and training:
PyTorch (broken)
https://pastebin.com/C5CMePYB
TensorFlow (correct)
https://pastebin.com/aWxABTCa
Thank you very much for your time and help!