Hi community!
I have a relatively large (conditional) GAN that I am training on a machine with multiple GPUs. I was using DataParallel in a straightforward manner: I made each of the generator and the discriminator a data-parallel model. This works and trains fine, everything as expected. However, I am looking to improve runtimes, and specifically I ran into the uneven memory usage across the different GPUs, so I read the following threads, among others (I can only put 2 links as a new user): 1, 2
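For reference, the straightforward setup looks roughly like this (placeholder networks just to illustrate the wrapping, not my actual conditional models):

```python
import torch
import torch.nn as nn

# Roughly the original setup (placeholder networks): each network is wrapped
# in DataParallel on its own, so its outputs are gathered back on the default
# GPU and the losses are computed there.
generator = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
discriminator = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))

g_model = nn.DataParallel(generator.cuda())
d_model = nn.DataParallel(discriminator.cuda())

noise = torch.randn(64, 16)
fake = g_model(noise)           # output gathered on the default GPU
d_out = d_model(fake.detach())  # output gathered on the default GPU as well
```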
So what I wanted to do, in essence, is push the loss calculation into the models. However, since I am dealing with a GAN, the discriminator loss requires the output of the generator, and the generator loss requires calling the discriminator. It is all intertwined, and I do not want either of the models to return its raw outputs to the caller, since those get gathered on GPU 0 and cause the uneven memory distribution.
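For a single network, the pattern suggested in those threads is, as far as I understand it, roughly the following (just a sketch with a placeholder discriminator and a BCE criterion, not my actual code):

```python
import torch
import torch.nn as nn


class DiscrWithLoss(nn.Module):
    """Illustrative wrapper: the loss is computed inside forward, so only a
    small per-sample loss tensor is gathered on the default GPU instead of
    the full discriminator outputs."""

    def __init__(self, d_model, criterion):
        super(DiscrWithLoss, self).__init__()
        self.d_model = d_model
        self.criterion = criterion

    def forward(self, x, label):
        out = self.d_model(x)
        return self.criterion(out, label)  # small per-sample tensor, cheap to gather


# Placeholder discriminator, just to show the wrapping:
d_model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
wrapped = nn.DataParallel(DiscrWithLoss(d_model, nn.BCEWithLogitsLoss(reduction='none')).cuda())

images, labels = torch.randn(64, 8), torch.ones(64, 1)
d_loss = wrapped(images, labels).mean()  # only the loss values travel back to GPU 0
```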
So what I did instead is put both the generator and the discriminator into one model. The only complication here is that I want to update the discriminator before updating the generator, so I have to issue two calls to the joint model per batch. The first call goes fine, and I use it to store on the model instance the generator inputs, its outputs and the ground-truth data used for supervision. The problem is that when I issue the second call to the model, the attributes I set during the first call are not there anymore, as if the model had been re-instantiated. The error I get is:
AttributeError: 'TrainingIteration' object has no attribute 'g_out'
BTW, looking at nvidia-smi it seems this approach indeed resolves the uneven memory issue.
I created a simplified, illustrative version of my code (it won’t run as-is, so excuse any bugs introduced by the simplification), and there is an even smaller standalone sketch of just the attribute issue after it. I would appreciate any help here: I would like to understand what the mechanism is that “steals” my variables away, but if anyone has another solution to what I am trying to do, that is of course welcome as well. Thanks in advance for any help with this.
Here is the code:
```python
import torch.nn as nn
import torch.optim as optim


class TrainingIteration(nn.Module):
    def __init__(self, config, g_model, d_model):
        super(TrainingIteration, self).__init__()
        self.config = config
        self.g_model = g_model
        self.d_model = d_model

    def forward(self, *args):
        nargin = len(args)
        if nargin > 0:  # discriminator training
            return self.forward_discr(*args)
        else:  # generator training
            return self.forward_gen()

    def forward_discr(self, inputs, target):
        self.inputs = inputs
        self.target = target
        self.g_out = self.g_model(self.inputs)
        # discr_lossfunc calls d_model twice
        d_loss = self.discr_lossfunc(self.g_out.detach(), self.target)
        return d_loss

    def forward_gen(self):
        D_fake = self.d_model(self.g_out)
        g_loss = self.gen_lossfunc(self.g_out, self.target, D_fake)
        return g_loss


class ModelTrainer:
    def __init__(self, config, g_model, d_model):
        self.config = config
        self.g_model = g_model
        self.d_model = d_model
        self.num_iter = 0
        # initialize TrainIteration
        self.train_iter = TrainingIteration(self.config, self.g_model, self.d_model)
        self.train_iter.cuda()
        self.train_iter = nn.DataParallel(self.train_iter, device_ids=self.config.cuda_device_num)
        # Other initializations ....
        # not sure this will work, I might have to use self.train_iter.g_model
        self.g_optimizer = optim.Adam(self.g_model.parameters(), self.config.lr, [self.config.beta1, self.config.beta2])
        self.d_optimizer = optim.Adam(self.d_model.parameters(), self.config.d_lr, [self.config.beta1, self.config.beta2])
        self.train_dataloader = ...

    def train(self):
        # ============================================
        # Training loop
        # ============================================
        # train the network
        for epoch in range(self.config.num_epochs):
            for batch_i, sample_batched in enumerate(self.train_dataloader):
                self.num_iter += 1  # counting the overall iteration number
                # ============================================
                # parse the inputs
                # ============================================
                inputs, target = sample_batched
                self.d_optimizer.zero_grad()
                # First call
                d_loss = self.train_iter(inputs, target)
                d_loss = d_loss.mean()
                d_loss.backward()
                # perform parameter update
                self.d_optimizer.step()
                # ============================================
                # TRAIN THE GENERATOR
                # ============================================
                # clear gradients.
                self.g_optimizer.zero_grad()
                # Second call - fails
                g_loss = self.train_iter()
                g_loss = g_loss.mean()
                # perform backprop
                g_loss.backward()
                # perform parameter update
                self.g_optimizer.step()
```
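For completeness, I believe the same attribute behaviour can be shown with a much smaller standalone module (again just a sketch: CacheInForward, the shapes and device_ids=[0, 1] are made up, and it assumes at least two visible GPUs):

```python
import torch
import torch.nn as nn


class CacheInForward(nn.Module):
    """Tiny stand-in for TrainingIteration: the first call stores an
    intermediate result on self, the second call tries to read it back."""

    def __init__(self):
        super(CacheInForward, self).__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x=None):
        if x is not None:                  # "discriminator" pass
            self.cached = self.lin(x)
            return self.cached.sum(dim=1)
        return self.cached.sum(dim=1)      # "generator" pass, reuses the cache


model = nn.DataParallel(CacheInForward().cuda(), device_ids=[0, 1])
model(torch.randn(8, 4))  # first call runs fine
model()                   # AttributeError: 'CacheInForward' object has no attribute 'cached'
```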