I have a main function that runs the forward and backward passes on a model, and I want to adapt it so that I can use a single model file for a GAN framework (two models: a discriminator and a generator).
With one model, one can simply call model.train(), but with two models there's a TypeError because two arguments are passed (the discriminator and the generator).
Is there a way to get around this so that I don't have to add conditionals to my main function?
Can't help you unless you post your code and the error here.
You're supposed to run them in two separate steps: first the generator runs, then the discriminator. Something like:
for phase in ['generator', 'discriminator']:
    if phase == 'generator':
        ...  # update the generator
    else: ...  # update the discriminator
In code, this is very much like training a supervised model, where each epoch runs separate training and validation phases.
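A minimal, self-contained sketch of that alternating loop, assuming toy fully connected models, one Adam optimizer per network and a BCE-with-logits loss. None of these choices come from the thread, and the names train_step, opt_g and opt_d are hypothetical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models: G maps 8-d noise to 2-d samples, D scores 2-d samples.
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
discriminator = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))

# One optimizer per network, so each phase only updates its own parameters.
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()


def train_step(real: torch.Tensor) -> dict:
    """Run one generator phase followed by one discriminator phase."""
    batch = real.size(0)
    ones = torch.ones(batch, 1)
    zeros = torch.zeros(batch, 1)
    losses = {}
    for phase in ['generator', 'discriminator']:
        noise = torch.randn(batch, 8)
        if phase == 'generator':
            # G tries to make D label its samples as real.
            opt_g.zero_grad()
            loss = criterion(discriminator(generator(noise)), ones)
            loss.backward()
            opt_g.step()
        else:
            # D learns to separate real samples from detached fakes.
            # zero_grad() also clears the D gradients left over from the G phase.
            opt_d.zero_grad()
            fake = generator(noise).detach()
            loss = criterion(discriminator(real), ones) + criterion(discriminator(fake), zeros)
            loss.backward()
            opt_d.step()
        losses[phase] = loss.item()
    return losses
```

Calling train_step(real_batch) once per batch performs one generator update and one discriminator update. Detaching the fakes in the discriminator phase keeps that loss from propagating into the generator.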
I suppose Generator and Discriminator are two separate classes with __init__ and forward methods. Have you tried adding a .train() method to the class?
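Yes, but if you wrap the model in torch.nn.DataParallel, the wrapper's own train() (the standard nn.Module.train) is what gets called, so the custom train method is effectively overridden.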
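A small sketch of what happens, assuming a hypothetical GAN container whose train() takes a phase string. DataParallel does not define its own train(), so calling it on the wrapper runs nn.Module.train(mode), which in turn calls train(mode) on each child; the wrapped model therefore receives a boolean instead of a phase:

```python
import torch
import torch.nn as nn


class GAN(nn.Module):
    """Hypothetical container whose train() takes a phase string instead of a bool."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(8, 2)
        self.discriminator = nn.Linear(2, 1)
        self.phase = None

    def train(self, phase='generator'):
        # Custom signature: record the phase, then put every submodule in training mode.
        self.phase = phase
        return super().train(True)


model = GAN()
wrapped = nn.DataParallel(model)

# nn.Module.train(mode) on the wrapper calls child.train(mode),
# so the custom method is invoked with the boolean True as its "phase".
wrapped.train()
print(model.phase)  # True

# The custom method is still reachable through the .module attribute.
wrapped.module.train('discriminator')
print(model.phase)  # discriminator
```

In recent PyTorch versions DataParallel simply wraps the module when no GPU is available, so this sketch also runs on a CPU-only machine.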
Yes, but I want to do that from within the class that holds both the discriminator and the generator, so that I can still call model.train() and let the class handle the phase, etc.
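One way to get there, sketched under assumptions (the GAN class, its toy layers and the phase keyword are all hypothetical): leave train(mode) untouched and pass the phase as an argument to forward(), which DataParallel passes through to the wrapped module. The container's forward() returns the loss for the requested phase, so the main loop only needs model.train() and one optimizer per phase:

```python
import torch
import torch.nn as nn


class GAN(nn.Module):
    """Hypothetical container holding both networks; forward() returns the loss for one phase."""

    def __init__(self, noise_dim=8, data_dim=2):
        super().__init__()
        self.noise_dim = noise_dim
        self.generator = nn.Sequential(nn.Linear(noise_dim, 16), nn.ReLU(), nn.Linear(16, data_dim))
        self.discriminator = nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, 1))
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, real, phase='generator'):
        # The phase travels as a forward() argument, so the inherited
        # train(mode) keeps its standard signature.
        noise = torch.randn(real.size(0), self.noise_dim, device=real.device)
        fake = self.generator(noise)
        ones = torch.ones(real.size(0), 1, device=real.device)
        if phase == 'generator':
            # The generator wants D to label its samples as real.
            return self.criterion(self.discriminator(fake), ones)
        if phase == 'discriminator':
            # D separates real samples from detached fakes (no gradient reaches G).
            zeros = torch.zeros_like(ones)
            return (self.criterion(self.discriminator(real), ones)
                    + self.criterion(self.discriminator(fake.detach()), zeros))
        raise ValueError(f'unknown phase: {phase}')


model = nn.DataParallel(GAN())
optimizers = {
    'generator': torch.optim.Adam(model.module.generator.parameters(), lr=1e-3),
    'discriminator': torch.optim.Adam(model.module.discriminator.parameters(), lr=1e-3),
}

model.train()  # standard nn.Module.train: G and D switch to training mode together
real = torch.randn(32, 2) + 3.0
for phase in ['generator', 'discriminator']:
    optimizers[phase].zero_grad()
    # With several GPUs DataParallel gathers one loss per replica; .mean() reduces to a scalar.
    loss = model(real, phase=phase).mean()
    loss.backward()
    optimizers[phase].step()
```

Because the inherited train() still recurses into both submodules, model.train() and model.eval() keep working with or without the DataParallel wrapper, and the main loop needs no GAN-specific conditionals beyond choosing the optimizer.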