Reconstructing a model from TensorFlow


I recently learned PyTorch and was trying to recreate a DCGAN that I had coded up in TensorFlow, but I'm having issues with its performance. The TensorFlow model was able to generate reasonable-looking MNIST digits and CelebA faces. With the same model in PyTorch, however, the discriminator loss converges to 0 and the generator loss steadily increases.

Now, the model recreated in PyTorch might not be exactly the same, because some PyTorch functions are slightly different from their TensorFlow counterparts. For instance, for convolution:
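```python
# TensorFlow (channels-last by default)
tf.layers.conv2d(relu1, 256, 5, strides=2, padding='same')
tf.layers.conv2d_transpose(x1, 128, 5, strides=2, padding='same')
# PyTorch (channels-first), matching the first TensorFlow layer
nn.Conv2d(128, 256, 5, 2, 2)
```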
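For reference, here is a minimal sketch of the shape bookkeeping for the two TensorFlow layers quoted above, assuming 28x28 inputs and a hypothetical `in_channels=256` for `x1` (the post does not show it). The transposed convolution needs `output_padding=1` to exactly double the spatial size the way TF's `'same'` does. One caveat: for even input sizes, TF's `'same'` padding is asymmetric (1 pixel before, 2 after for kernel 5 and stride 2), while PyTorch's `padding=2` is symmetric, so the output shapes match but the sampling grid is shifted by one pixel. That changes which weights get learned slightly, but it is unlikely to break training on its own.

```python
import torch
import torch.nn as nn

# TF: tf.layers.conv2d(relu1, 256, 5, strides=2, padding='same')
# padding=2 gives output size ceil(input / 2), like TF 'same' for this layer
down = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)

# TF: tf.layers.conv2d_transpose(x1, 128, 5, strides=2, padding='same') doubles H and W.
# in_channels=256 is a hypothetical value; output_padding=1 makes the output exactly 2x.
up = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1)

x = torch.randn(1, 128, 28, 28)  # NCHW input
h = down(x)
y = up(h)
print(h.shape, y.shape)  # torch.Size([1, 256, 14, 14]) torch.Size([1, 128, 28, 28])
```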

Similarly for batch normalization.

As you can see, I copied the default TensorFlow parameters over to PyTorch. I am also aware that the channel axis differs between PyTorch and TensorFlow, and I even permuted the inputs and outputs to match TensorFlow's behaviour. However, nothing seems to help.
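A minimal sketch of the layout conversion, assuming the data arrives channels-last as in TensorFlow's default; `nhwc_to_nchw` and `nchw_to_nhwc` are hypothetical helper names. PyTorch layers such as `nn.Conv2d` and `nn.BatchNorm2d` treat dimension 1 as the channel axis, so once a batch is NCHW before it enters the network, no further permuting inside the model should be needed. The one place layout still matters is flattening before a fully connected layer: NCHW and NHWC flatten in different orders, which is only a concern if weights are copied between frameworks.

```python
import torch

def nhwc_to_nchw(x):
    # TensorFlow layout (batch, height, width, channels) -> PyTorch (batch, channels, height, width)
    return x.permute(0, 3, 1, 2).contiguous()

def nchw_to_nhwc(x):
    # Inverse permutation, e.g. for saving or plotting generator output
    return x.permute(0, 2, 3, 1).contiguous()

images = torch.randn(16, 28, 28, 1)  # e.g. an MNIST batch stored channels-last
x = nhwc_to_nchw(images)
print(x.shape)  # torch.Size([16, 1, 28, 28])
```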

I also thought that the difference might be the model initialization. Below are images generated by an untrained network (each is a grid of images):



PyTorch (manual Xavier Uniform Initialization)

It seems like PyTorch's default initialization is very different from TensorFlow's. What kind of initialization does PyTorch use?
Even when using Xavier initialization, though, I still get the same behavior. Increasing the generator's learning rate does stop the discriminator loss from decaying to 0 and the generator loss from increasing, but then the model fails to learn anything useful.
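To the best of my knowledge, `nn.Conv2d`, `nn.ConvTranspose2d` and `nn.Linear` initialize their weights with `kaiming_uniform_(a=math.sqrt(5))`, which works out to roughly U(-1/sqrt(fan_in), 1/sqrt(fan_in)), and draw biases from a similar range, whereas `tf.layers` defaults to Glorot (Xavier) uniform kernels and zero biases. Below is a minimal sketch that mimics the TensorFlow defaults with `Module.apply`; `tf_style_init` is a hypothetical helper name and the layer sizes are placeholders, not the poster's architecture.

```python
import math
import torch
import torch.nn as nn

def tf_style_init(module):
    """Hypothetical helper: Glorot/Xavier uniform kernels and zero biases, like tf.layers defaults."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(
    nn.Conv2d(1, 64, 5, 2, 2),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 5, 2, 2),
)
model.apply(tf_style_init)  # apply() visits every submodule recursively

# Glorot bound: sqrt(6 / (fan_in + fan_out)), where fan = channels * kernel area
bound = math.sqrt(6 / (64 * 25 + 128 * 25))
```

Note that with identical initialization the TensorFlow and PyTorch samples from an untrained network should look statistically similar, so if they still differ after this change, the difference is more likely in the forward pass than in the initializer.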

TensorFlow (run with fewer iterations)

Does anyone know what could be causing this?

Here are the discriminator and generator losses:
PyTorch (broken)
TensorFlow (correct)

Here is the code for the model and training:
PyTorch (broken)
TensorFlow (correct)

Thank you very much for your time and help!

Can't the PyTorch DCGAN example help you?

Thank you for the response! Yes, I have looked at other DCGAN examples. My goal isn't just to implement a DCGAN, but to reimplement one that I previously built in TensorFlow so that I understand PyTorch better. My question is why the model doesn't work in PyTorch when it works in TensorFlow. It is probably some PyTorch convention that I missed because it doesn't exist in TensorFlow.
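A few PyTorch conventions with no direct TensorFlow 1.x counterpart are common culprits when porting GANs: gradients accumulate until `zero_grad()` is called, the fake batch has to be `detach()`ed for the discriminator update, and BatchNorm behaviour is controlled by `train()`/`eval()` rather than a `training=` argument. Here is a minimal sketch of one update step, assuming a discriminator that outputs raw logits of shape `(batch, 1)`. The `train_step` helper, the tiny fully connected stand-in models, and the Adam settings (`betas=(0.5, 0.999)`, a common DCGAN choice) are illustrative assumptions, not the poster's code.

```python
import torch
import torch.nn as nn

def train_step(G, D, opt_G, opt_D, real, z_dim=100):
    """Hypothetical helper: one DCGAN-style update using BCEWithLogitsLoss."""
    G.train()  # BatchNorm/Dropout behave differently in train vs eval mode
    D.train()
    bce = nn.BCEWithLogitsLoss()
    batch = real.size(0)
    ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

    # Discriminator update
    opt_D.zero_grad()  # gradients accumulate across backward() calls unless cleared
    fake = G(torch.randn(batch, z_dim))
    # detach() stops the discriminator loss from sending gradients into G
    d_loss = bce(D(real), ones) + bce(D(fake.detach()), zeros)
    d_loss.backward()
    opt_D.step()

    # Generator update (non-saturating loss: label the fakes as real)
    opt_G.zero_grad()
    g_loss = bce(D(fake), ones)
    g_loss.backward()  # also writes grads into D; cleared by opt_D.zero_grad() next step
    opt_G.step()
    return d_loss.item(), g_loss.item()

# Tiny fully connected stand-ins, just to exercise the function
G = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
D = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
real = torch.rand(8, 784) * 2 - 1  # placeholder batch scaled to [-1, 1]
d_loss, g_loss = train_step(G, D, opt_G, opt_D, real)
```

Forgetting `zero_grad()` in particular lets gradients from earlier steps pile up, which can produce the kind of runaway discriminator described above.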

Also, here are some examples of images generated after training for only a couple of seconds.
It appears that the model is already suffering from mode collapse.

How could this be if it was working perfectly on TensorFlow?

What happens when you use the default parameters for nn.BatchNorm2d (eps=1e-05, momentum=0.1)?

Same issue. I changed the parameters because those are the default parameters for TensorFlow.
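One detail worth double-checking when copying batch-norm defaults: `tf.layers.batch_normalization` uses `momentum=0.99, epsilon=0.001`, but PyTorch's `momentum` weights the *new* batch statistic, so TF's 0.99 corresponds to PyTorch's 0.01, not 0.99. A minimal sketch of the conversion, with 64 channels as a placeholder. Note that during training both frameworks normalize with the current batch statistics, so momentum only affects the running statistics used in eval mode; a mismatch here is unlikely to explain the training-loss curves by itself.

```python
import torch
import torch.nn as nn

# tf.layers.batch_normalization defaults: momentum=0.99, epsilon=0.001.
# PyTorch updates running_mean = (1 - momentum) * running_mean + momentum * batch_mean,
# so TF momentum m corresponds to PyTorch momentum 1 - m.
tf_momentum, tf_epsilon = 0.99, 1e-3
bn = nn.BatchNorm2d(64, eps=tf_epsilon, momentum=1 - tf_momentum)

# Check the running-mean update on one constant batch (batch mean = 5.0)
bn.train()
bn(torch.full((4, 64, 2, 2), 5.0))
print(bn.running_mean[:3])  # each entry is about 0.05 (= 0.01 * 5.0)
```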

Another thing I noticed is that in the PyTorch implementation the discriminator loss starts out very low. How can that be? I couldn't see any difference between the loss functions in the two implementations.
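A rough sanity check, assuming the discriminator loss is the sum of a real term and a fake term (as in most DCGAN code): an untrained discriminator usually emits logits near 0, i.e. probabilities near 0.5, so each term should start around ln 2 and the sum around 2 ln 2. If the logged value at step 0 is far below that, the logits are probably not near zero at initialization, or the logged quantity is not the sum it is assumed to be. The zero logits below are a stand-in for a fresh discriminator's output.

```python
import math
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
batch = 64
# Stand-in for an untrained discriminator: logits of 0 mean sigmoid = 0.5
real_logits = torch.zeros(batch, 1)
fake_logits = torch.zeros(batch, 1)

d_loss_real = bce(real_logits, torch.ones(batch, 1))   # real images labelled 1
d_loss_fake = bce(fake_logits, torch.zeros(batch, 1))  # fake images labelled 0
d_loss = d_loss_real + d_loss_fake

print(d_loss_real.item(), math.log(2))  # both about 0.693
print(d_loss.item())                     # about 1.386 = 2 * ln 2
```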

There should be no large difference between nn.BCELoss() and nn.BCEWithLogitsLoss(), but how about trying to replace nn.BCEWithLogitsLoss() in your code with nn.BCELoss() (adding a sigmoid to the discriminator output, since nn.BCELoss() expects probabilities)? The PyTorch DCGAN example uses nn.BCELoss(), but presumably just for simplicity.

I tried BCELoss before and still observed the same effect. Isn't BCEWithLogitsLoss better, though, because BCELoss has numerical stability issues?
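A small comparison showing both points: for moderate logits, `nn.BCEWithLogitsLoss` on raw logits matches `nn.BCELoss` on `sigmoid(logits)`, but for extreme logits `sigmoid` underflows to 0 in float32 and `nn.BCELoss` clamps `log(0)` to -100, so the two diverge. The logits and targets below are arbitrary illustrative values. This supports keeping `BCEWithLogitsLoss`, as long as the discriminator returns raw logits (no final sigmoid).

```python
import torch
import torch.nn as nn

logits = torch.tensor([[-3.0], [0.5], [4.0]])
targets = torch.tensor([[0.0], [1.0], [1.0]])

# BCEWithLogitsLoss applies the sigmoid internally in a numerically stable way
stable = nn.BCEWithLogitsLoss()(logits, targets)
# BCELoss expects probabilities, so the sigmoid must be applied first
plain = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(stable, plain))  # True for moderate logits

# Extreme logit: sigmoid(-200) underflows to 0 in float32, and BCELoss
# clamps log(0) to -100, so the two losses now disagree.
x, y = torch.tensor([-200.0]), torch.tensor([1.0])
extreme_stable = nn.BCEWithLogitsLoss()(x, y).item()
extreme_plain = nn.BCELoss()(torch.sigmoid(x), y).item()
print(extreme_stable, extreme_plain)  # 200.0 100.0
```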

You are right. I was just suspicious of BCEWithLogitsLoss.

So was I, but changing it still gave the same outcome. I also noticed that my discriminator was missing a LeakyReLU layer compared with the TensorFlow implementation. I added it and changed the LeakyReLU slope (alpha) to match TensorFlow's default, but I am still getting the same outcome.
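For reference, the two defaults differ: `tf.nn.leaky_relu` uses `alpha=0.2`, while PyTorch's `nn.LeakyReLU` uses `negative_slope=0.01`. (If the TensorFlow code used the Keras `LeakyReLU` layer instead, its default is 0.3, so it is worth checking which one the original model calls.) A minimal sketch, assuming `tf.nn.leaky_relu` is the function being matched:

```python
import torch
import torch.nn as nn

# Assumes the TF model uses tf.nn.leaky_relu (alpha=0.2 by default);
# nn.LeakyReLU defaults to negative_slope=0.01.
pytorch_default = nn.LeakyReLU()
tf_matching = nn.LeakyReLU(negative_slope=0.2)

x = torch.tensor([-1.0, 2.0])
out_default = pytorch_default(x)  # negative input scaled by 0.01
out_tf = tf_matching(x)           # negative input scaled by 0.2, like tf.nn.leaky_relu
print(out_default, out_tf)
```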

Some people on this forum probably can't decipher the TensorFlow code. If you outline the architecture used in your TensorFlow code, you will likely get more help.

Okay, although my main concern isn't the architecture itself, but rather why it works in TensorFlow and not in PyTorch.