I have a relatively large conditional GAN that I am training on a machine with multiple GPUs. I was using DataParallel in a straightforward manner: wrapping each of my generator and discriminator in its own data-parallel model. This works, trains fine, and everything is as expected. However, I am looking to improve runtimes, and specifically I ran into uneven memory usage across the GPUs, so I read the following threads among others (I can only include 2 links as a new user): 1, 2
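For reference, my original setup was roughly the sketch below. The Generator and Discriminator here are tiny dummy stand-ins I made up for illustration, not my real networks, and the loss functions and training loop are omitted:

import torch
import torch.nn as nn

# Dummy stand-ins for my real (much larger) conditional generator/discriminator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Linear(16, 16)

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Linear(16, 1)

    def forward(self, x):
        return self.net(x)

# each model wrapped in its own DataParallel, as I was doing originally
g_model = nn.DataParallel(Generator().cuda())
d_model = nn.DataParallel(Discriminator().cuda())

inputs = torch.randn(8, 16).cuda()
g_out = g_model(inputs)            # output gathered back onto GPU 0
d_out = d_model(g_out.detach())    # same here, hence the uneven memory usage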
So what I wanted to do, in essence, is to push the loss calculation into the models themselves, so that each forward call returns only the loss (roughly as in the sketch below). However, since I am dealing with a GAN, the discriminator loss requires the output of the generator, and the generator loss requires calling the discriminator. It is all intertwined, and I do not want either model to return its outputs to the caller, because those would be gathered on GPU 0 and cause the uneven memory distribution.
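A minimal sketch of that pattern, shown only for the discriminator side; DLossWrapper and the BCE criterion are placeholders I made up for illustration, not my actual losses:

import torch
import torch.nn as nn

class DLossWrapper(nn.Module):
    # Wraps the discriminator so that forward() returns the loss itself.
    # DataParallel then only gathers one scalar per device back on GPU 0,
    # instead of gathering the full discriminator outputs there.
    def __init__(self, d_model):
        super(DLossWrapper, self).__init__()
        self.d_model = d_model
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, fake, real):
        real_pred = self.d_model(real)
        fake_pred = self.d_model(fake)
        real_loss = self.criterion(real_pred, torch.ones_like(real_pred))
        fake_loss = self.criterion(fake_pred, torch.zeros_like(fake_pred))
        return real_loss + fake_loss

Usage would be something like d_loss = d_loss_wrapper(g_out.detach(), target).mean(), so only scalar losses travel back to GPU 0. The generator loss would need the same treatment, which is where everything gets entangled.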
So what I did instead was to put both the generator and the discriminator into one model. The only complication is that I want to update the discriminator before updating the generator, so I have to issue two calls to the joint model. The first call goes fine, and I use it to store the generator inputs, outputs, and the ground-truth data used for supervision as attributes on the model instance. The problem is that when I issue the second call to the model, the attributes I set during the first call are not there anymore, as if the model had been re-instantiated. The error I get is:
AttributeError: 'TrainingIteration' object has no attribute 'g_out'
BTW, looking at nvidia-smi it seems this approach indeed resolves the uneven memory issue.
I created a simplified, illustrative version of my code (it won’t run as-is, so excuse any bugs introduced by the simplification) and would appreciate any help here: I would like to understand what mechanism “steals” my variables away, but if anyone has a different solution to what I am trying to do, that is of course welcome as well. Thanks in advance for any help with this.
Here is the code:
import torch.nn as nn
import torch.optim as optim


class TrainingIteration(nn.Module):
    def __init__(self, config, g_model, d_model):
        super(TrainingIteration, self).__init__()
        self.config = config
        self.g_model = g_model
        self.d_model = d_model

    def forward(self, *args):
        nargin = len(args)
        if nargin > 0:
            # discriminator training
            return self.forward_discr(*args)
        else:
            # generator training
            return self.forward_gen()

    def forward_discr(self, inputs, target):
        self.inputs = inputs
        self.target = target
        self.g_out = self.g_model(self.inputs)
        # discr_lossfunc calls d_model twice
        d_loss = self.discr_lossfunc(self.g_out.detach(), self.target)
        return d_loss

    def forward_gen(self):
        D_fake = self.d_model(self.g_out)
        g_loss = self.gen_lossfunc(self.g_out, self.target, D_fake)
        return g_loss


class ModelTrainer:
    def __init__(self, config, g_model, d_model):
        self.config = config
        self.g_model = g_model
        self.d_model = d_model
        self.num_iter = 0

        # initialize TrainIteration
        self.train_iter = TrainingIteration(self.config, self.g_model, self.d_model)
        self.train_iter.cuda()
        self.train_iter = nn.DataParallel(self.train_iter, device_ids=self.config.cuda_device_num)

        # Other initializations ....

        # not sure this will work, I might have to use self.train_iter.g_model
        self.g_optimizer = optim.Adam(self.g_model.parameters(), self.config.lr,
                                      [self.config.beta1, self.config.beta2])
        self.d_optimizer = optim.Adam(self.d_model.parameters(), self.config.d_lr,
                                      [self.config.beta1, self.config.beta2])

        self.train_dataloader = ...

    def train(self):
        # ============================================
        # Training loop
        # ============================================
        # train the network
        for epoch in range(self.config.num_epochs):
            for batch_i, sample_batched in enumerate(self.train_dataloader):
                self.num_iter += 1  # counting the overall iteration number

                # ============================================
                # parse the inputs
                # ============================================
                inputs, target = sample_batched

                self.d_optimizer.zero_grad()

                # First call
                d_loss = self.train_iter(inputs, target)
                d_loss = d_loss.mean()
                d_loss.backward()

                # perform parameter update
                self.d_optimizer.step()

                # ============================================
                # TRAIN THE GENERATOR
                # ============================================
                # clear gradients.
                self.g_optimizer.zero_grad()

                # Second call - fails
                g_loss = self.train_iter()
                g_loss = g_loss.mean()

                # perform backprop
                g_loss.backward()

                # perform parameter update
                self.g_optimizer.step()