I have a main function that runs the forward and backward passes on a model, and I want to adapt it so that I can use a single model file for a GAN framework (two models: a discriminator and a generator).
With one model I can simply call model.train(), but with two models I get a TypeError because two arguments (the discriminator and the generator) are passed.
Is there a way to get around this so that I don't have to add conditions to my main function?
I suppose the Generator and Discriminator models are two separate classes with __init__ and forward methods. Have you tried adding .train() as a method to the class?
Yes, but I want to do that from within the class that holds both the discriminator and the generator, so that I can still call model.train() and let that class handle the phase, etc.
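A minimal sketch of that idea, assuming PyTorch: if the holding class itself subclasses nn.Module and stores both networks as attributes, they are registered as submodules, and the built-in model.train() / model.eval() already recurse into both of them, so no train() override is needed. The class names GAN, Generator and Discriminator, the layer sizes, and the phase argument of forward are hypothetical choices for illustration, not part of the original code.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    # Hypothetical toy generator: maps a latent vector to a fake sample.
    def __init__(self, latent_dim: int = 8, out_dim: int = 4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, out_dim))

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    # Hypothetical toy discriminator: outputs one real/fake logit per sample.
    def __init__(self, in_dim: int = 4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))

    def forward(self, x):
        return self.net(x)


class GAN(nn.Module):
    """Single wrapper so the main loop can keep calling model.train() / model.eval()."""

    def __init__(self, latent_dim: int = 8, data_dim: int = 4):
        super().__init__()
        self.latent_dim = latent_dim
        # Assigning nn.Module attributes registers them as submodules, so
        # nn.Module.train()/eval() on the wrapper propagate to both of them.
        self.generator = Generator(latent_dim, data_dim)
        self.discriminator = Discriminator(data_dim)

    def forward(self, x, phase: str = "discriminator"):
        # Hypothetical phase switch: the wrapper decides what to run, not the main loop.
        z = torch.randn(x.size(0), self.latent_dim, device=x.device)
        fake = self.generator(z)
        if phase == "discriminator":
            # Detach so the discriminator step does not backpropagate into the generator.
            return self.discriminator(x), self.discriminator(fake.detach())
        if phase == "generator":
            return self.discriminator(fake)
        raise ValueError(f"unknown phase: {phase!r}")


model = GAN()
model.train()  # no extra arguments; sets training=True on both submodules
real = torch.randn(5, 4)
real_logits, fake_logits = model(real, phase="discriminator")
gen_logits = model(real, phase="generator")
model.eval()  # switches both submodules to evaluation mode
```

With this layout the main function only needs one extra piece of information (the phase) rather than branching on the model type. The two networks would still normally get separate optimizers, e.g. built from model.generator.parameters() and model.discriminator.parameters().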