Training the last N layers of a GAN

Hi, I am really new to all this and am working on an assignment based on the DCGAN from the DCGAN Tutorial — PyTorch Tutorials 1.8.1+cu102 documentation.

I have trained this GAN on dataset A. I now need to train only the last N layers (N = 1, 2, 3, etc.) of the already-trained GAN on dataset B. How can I do this?

So far I have used nn.ModuleList() to collect all the layers and then taken the last N to train, but my results are not ideal. Is there a better way to do this?

Thanks in advance!

Usually you would freeze the earlier layers, which should not be trained, by setting the requires_grad attribute of their parameters to False. I’m unsure how you are using the nn.ModuleList, though. Could you explain your current approach to fine-tuning the last layers in a bit more detail, please?
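A minimal sketch of "freeze everything except the last N layers". The helper name freeze_all_but_last_n is hypothetical, and it assumes that a "layer" means a module that owns parameters directly (ConvTranspose2d, BatchNorm2d), so parameter-free ReLU/Tanh are skipped. It also assumes registration order matches forward order, which holds for the tutorial's nn.Sequential-based models. The stand-in netG below has the tutorial's layer pattern with shrunken, made-up sizes:

```python
import torch.nn as nn


def freeze_all_but_last_n(model: nn.Module, n: int) -> None:
    """Hypothetical helper: keep only the last n parameterized layers trainable.

    A "layer" here is any module that owns parameters directly
    (e.g. ConvTranspose2d, BatchNorm2d); ReLU/Tanh own none and are skipped.
    Order follows model.modules(), i.e. registration order.
    """
    layers = [m for m in model.modules()
              if len(list(m.parameters(recurse=False))) > 0]
    for i, layer in enumerate(layers):
        trainable = i >= len(layers) - n  # True only for the last n layers
        for p in layer.parameters(recurse=False):
            p.requires_grad = trainable


# Stand-in with the tutorial Generator's ConvTranspose2d/BatchNorm2d/ReLU
# pattern; the sizes (nz, ngf, nc) are assumptions chosen to keep it tiny.
nz, ngf, nc = 8, 4, 1
netG = nn.Sequential(
    nn.ConvTranspose2d(nz, ngf * 2, 4, 1, 0, bias=False),
    nn.BatchNorm2d(ngf * 2),
    nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf),
    nn.ReLU(True),
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
    nn.Tanh(),
)

freeze_all_but_last_n(netG, n=1)
for name, p in netG.named_parameters():
    print(name, p.requires_grad)  # only the final ConvTranspose2d stays trainable
```

Because model.modules() recurses, the same call also works on a full Generator whose layers live inside self.main. With the tutorial's layout, N = 1 selects the final ConvTranspose2d, N = 2 adds the BatchNorm2d before it, and so on.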

Sure thing!

Right now my approach is this (but I am unsure if I am correct):

  1. Train DCGAN1 with dataset A
  2. Train DCGAN2 with dataset B
  3. Decide on N, i.e. how many layers of DCGAN1 need to be retrained on dataset B.
  4. Create two nn.ModuleList() objects. In the first, store layers 1 through (total - N) of the trained DCGAN1. E.g. if my DCGAN has 8 layers and I want to “retrain” the last 3, I would take layers 1-5 from this GAN and store them in the first list.
  5. In the second list, take the last N layers from the trained DCGAN2 and store them.
  6. Combine both lists and declare a new Generator and Discriminator by iterating over this combined list.

Hope this makes sense.

In your answer, do you mean that if I have a DCGAN trained on, say, MNIST and want to retrain the last N layers on EMNIST, I could freeze the previous layers by setting requires_grad to False and then retrain this partially “frozen” GAN on EMNIST?

Apologies for the lengthy response, and thank you for answering my query. I appreciate it!

Thanks for the explanation.
I’m still unsure why you are rewrapping the layers in two nn.ModuleLists. I assume you are using one nn.ModuleList for the “frozen” layers and another for the ones that should be fine-tuned?
Rewrapping the model isn't necessary and can break the architecture (e.g. functional API calls made in the original forward method would be missing). Instead, you can just manipulate the .requires_grad attribute to freeze parameters, as mentioned before. Here is a small example:

for param in model.layer_to_freeze.parameters():
    param.requires_grad = False

where layer_to_freeze should be replaced with the layer names in your model.
In case you are using e.g. nn.Sequential modules internally, you could of course use them as parameters() will return all parameters recursively.
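To find out what to use in place of layer_to_freeze, you can print the registered module names. A minimal sketch, assuming a small stand-in Generator that keeps the tutorial's layout (all layers inside self.main, an nn.Sequential) but uses hypothetical, shrunken sizes. Layers inside an nn.Sequential are named "0", "1", ..., so model.layer_to_freeze becomes netG.main[0] here:

```python
import torch.nn as nn


class Generator(nn.Module):
    # Stand-in modeled on the tutorial's Generator: layers live in self.main.
    # Sizes are assumptions for illustration only.
    def __init__(self, nz=8, ngf=4, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 2, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.main(x)


netG = Generator()

# Print every registered name, e.g. 'main.0' -> ConvTranspose2d
for name, module in netG.named_modules():
    print(repr(name), type(module).__name__)

# Freeze one layer by name: here the first ConvTranspose2d
for param in netG.main[0].parameters():
    param.requires_grad = False
```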

Okay that makes a lot of sense, thank you!

I was using the module lists because I wasn't sure how to apply the freezing technique and was working against a deadline, so the steps I described were a workaround that avoids freezing altogether. I did get some results from it, but as you said, I probably broke something, because the results were not ideal.

However, I'm not sure what you mean by this:

In case you are using e.g. nn.Sequential modules internally, you could of course use them as parameters() will return all parameters recursively.

I only meant that, if your model uses “submodules” (such as an nn.Sequential container), you could call .parameters() on the submodule instead of on each layer separately. That’s just a minor suggestion, though.
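As a sketch of that suggestion: slicing an nn.Sequential returns a new container that holds the same module objects, so calling .parameters() on the slice reaches the original, shared parameters and freezes a whole range of layers in one loop. The Generator below is an assumed stand-in with the tutorial's layer pattern (four ConvTranspose2d/BatchNorm2d/ReLU blocks plus a final ConvTranspose2d/Tanh) but hypothetical, tiny channel counts:

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    # Stand-in with the tutorial Generator's layer pattern; sizes are assumptions.
    def __init__(self, nz=8, ngf=2, nc=1):
        super().__init__()

        def block(cin, cout, stride, pad):
            return [nn.ConvTranspose2d(cin, cout, 4, stride, pad, bias=False),
                    nn.BatchNorm2d(cout),
                    nn.ReLU(True)]

        self.main = nn.Sequential(
            *block(nz, ngf * 8, 1, 0),
            *block(ngf * 8, ngf * 4, 2, 1),
            *block(ngf * 4, ngf * 2, 2, 1),
            *block(ngf * 2, ngf, 2, 1),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.main(x)


netG = Generator()

# netG.main[:6] covers the first two ConvTranspose2d/BatchNorm2d/ReLU blocks;
# .parameters() recurses into every module in the slice.
for param in netG.main[:6].parameters():
    param.requires_grad = False

out = netG(torch.randn(2, 8, 1, 1))
print(out.shape)
```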

Thanks, I understand now! Appreciate your help!

Hi, when I run the approach you recommended, my code takes a REALLY long time. I am trying to freeze only the first few layers of the GAN from here: DCGAN Tutorial — PyTorch Tutorials 1.8.1+cu102 documentation, but even freezing the first 3 layers takes about an hour.

Is this expected?

No, that’s not expected. Could you post an executable code snippet that reproduces the ~hour-long “freezing”?

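I think I've fixed it. I'm now using this snippet to freeze the first six layers:

# netG is the Generator instance from the tutorial. Its only child is
# netG.main (an nn.Sequential), whose children are the individual layers,
# so the first six are ConvTranspose2d, BatchNorm2d, ReLU (twice).
G_grandchild_count = 0

# Flatten both levels into one loop so that `break` stops everything
grandchildren = (gc for child in netG.children() for gc in child.children())

for grandchild in grandchildren:
    G_grandchild_count += 1
    print(G_grandchild_count, type(grandchild).__name__)

    for param in grandchild.parameters():
        param.requires_grad = False

    print("Just set requires_grad=False for these params")

    if G_grandchild_count == 6:
        break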
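Once the first layers are frozen as in the snippet above, the remaining layers can be fine-tuned on the new dataset. A minimal sketch, assuming a tiny stand-in netG, a random batch in place of dataset B, and a dummy loss (fake.mean()) in place of the real GAN loss; the Adam settings mirror the tutorial's lr=0.0002 and beta1=0.5. It builds the optimizer only from parameters that still have requires_grad=True and checks that a frozen weight is unchanged after one step:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Tiny stand-in for the tutorial's netG (sizes are assumptions)
netG = nn.Sequential(
    nn.ConvTranspose2d(8, 4, 4, 1, 0, bias=False), nn.BatchNorm2d(4), nn.ReLU(True),
    nn.ConvTranspose2d(4, 1, 4, 2, 1, bias=False), nn.Tanh(),
)

# Freeze the first ConvTranspose2d/BatchNorm2d/ReLU block
for param in netG[:3].parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that are still trainable
trainable = [p for p in netG.parameters() if p.requires_grad]
optimizerG = optim.Adam(trainable, lr=2e-4, betas=(0.5, 0.999))

frozen_before = netG[0].weight.detach().clone()
last_before = netG[3].weight.detach().clone()

# One dummy update; a real loop would use the GAN loss on dataset B
optimizerG.zero_grad()
fake = netG(torch.randn(4, 8, 1, 1))
loss = fake.mean()
loss.backward()
optimizerG.step()

print(torch.equal(netG[0].weight, frozen_before))  # frozen layer unchanged
```

Passing only the trainable parameters is not strictly required, since optimizers skip parameters whose .grad is None, but it avoids a subtle issue: a frozen parameter that still holds a zero-filled .grad from earlier training (e.g. after optimizer.zero_grad(set_to_none=False)) can keep moving if an Adam optimizer already has momentum state for it. Also note that requires_grad=False does not freeze BatchNorm2d running statistics; those buffers keep updating while the model is in train() mode.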