Method overwrite for combining discriminator and generator into one NN module

I have a main function that runs the forward and backward passes on a model. I want to adapt it so that a single model file can serve a GAN framework with two models: a discriminator and a generator.

With one model, one can call model.train(), but with two models there’s a TypeError because two arguments are passed (discriminator and generator).

Is there a way to circumvent this such that I don’t have to put conditions on my main function?

I can’t help unless you post your code and the full error message here.

You’re supposed to run them in two different steps: first the generator, then the discriminator. Something like:

for phase in ['generator','discriminator']:
    if phase == 'generator':
        generator.train()
    else:
        discriminator.train()

In code, this is very much like a supervised training loop that alternates between a training phase and a validation phase.
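To make the two-phase idea concrete, here is a minimal sketch of one alternating step. The toy networks, layer sizes, optimizers, and the BCE-with-logits loss are all assumptions chosen for illustration; none of them come from the original poster's code.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy networks; the layer sizes are placeholders for illustration only.
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
discriminator = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))

opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(32, 2)   # stand-in for a batch of real data
ones = torch.ones(32, 1)    # label "real"
zeros = torch.zeros(32, 1)  # label "fake"

losses = {}
for phase in ["generator", "discriminator"]:
    noise = torch.randn(32, 8)
    if phase == "generator":
        generator.train()
        opt_g.zero_grad()
        # The generator wants its samples to be classified as real.
        loss = criterion(discriminator(generator(noise)), ones)
        loss.backward()
        opt_g.step()
    else:
        discriminator.train()
        opt_d.zero_grad()
        # detach() keeps this step from backpropagating into the generator.
        fake = generator(noise).detach()
        loss = criterion(discriminator(real), ones) + criterion(discriminator(fake), zeros)
        loss.backward()
        opt_d.step()
    losses[phase] = loss.item()
```

Each phase has its own optimizer, so only one network's weights change per phase even though both networks take part in the forward pass.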

I suppose the Generator and Discriminator models are two separate classes with __init__ and forward methods. Have you tried adding .train() as a method to the class?

Yes, but if one calls torch.nn.DataParallel on the class, it overwrites the train method.

Yes, but I want to do that from within the class that holds both the discriminator and generator such that I can still call model.train() and let the class handle the phase, etc.
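One possible way to get that behaviour is sketched below, under some assumptions: the class name GAN, the phase attribute, and the set_phase helper are hypothetical names, and putting the inactive network into eval mode is a design choice borrowed from the phase loop earlier in the thread, not a requirement. The key point is that train keeps the same signature as nn.Module.train (a single optional mode argument), so existing calls to model.train() and model.eval() keep working.

```python
import torch
from torch import nn


class GAN(nn.Module):
    """Hypothetical container that holds both networks."""

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.phase = "generator"  # which sub-network is currently optimized

    def set_phase(self, phase):
        if phase not in ("generator", "discriminator"):
            raise ValueError(f"unknown phase: {phase!r}")
        self.phase = phase
        return self

    def train(self, mode=True):
        # Same signature as nn.Module.train, so eval() and wrappers still work.
        super().train(mode)
        if mode:
            # Only the active sub-network stays in training mode.
            self.generator.train(self.phase == "generator")
            self.discriminator.train(self.phase == "discriminator")
        return self

    def forward(self, z):
        # Generate samples from noise and score them.
        return self.discriminator(self.generator(z))


# Usage: the main function only ever calls model.train() / model.eval().
model = GAN(nn.Linear(4, 8), nn.Linear(8, 1))
model.set_phase("discriminator").train()

# nn.Module.train recurses into child modules, so calling train() on a
# DataParallel wrapper reaches the override above; the original object
# remains available as wrapped.module.
wrapped = nn.DataParallel(model)
wrapped.train()
```

Helpers with other names, such as set_phase here, are not forwarded by the wrapper, so with DataParallel they would be called as wrapped.module.set_phase(...).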