Internal model variables not saved between model calls when using DataParallel

Hi community!
I have a relatively large conditional GAN that I am training on a machine with multiple GPUs. I was using DataParallel in a straightforward manner, wrapping each of my generator and discriminator in its own DataParallel model. This works: it trains fine and everything behaves as expected. However, I am looking to improve runtimes. Specifically, I ran into uneven memory usage across the GPUs, so I read the following threads, among others (as a new user I can only post 2 links): 1, 2
So what I wanted to do, in essence, is to push the loss calculation into the models. However, since I am dealing with a GAN, the discriminator loss requires the output of the generator, and the generator loss requires calling the discriminator. So everything is intertwined, and I do not want any of the models to return their outputs to the caller, as those outputs would be gathered on GPU 0 and cause the uneven memory distribution.
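A minimal sketch of the general "loss inside the model" pattern from those threads, where each replica returns only a small loss tensor instead of its full output. ModelWithLoss is a hypothetical name, and nn.Linear with nn.MSELoss stand in for a real model and loss; this illustrates the pattern only, not a GAN-specific solution.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: each replica computes its own loss and returns only that."""

    def __init__(self, model: nn.Module, loss_fn: nn.Module):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        out = self.model(inputs)          # the large output stays on this replica's device
        loss = self.loss_fn(out, target)  # 0-dim tensor
        return loss.unsqueeze(0)          # shape (1,): gathered into one value per GPU


device = "cuda" if torch.cuda.is_available() else "cpu"
wrapped = ModelWithLoss(nn.Linear(8, 4), nn.MSELoss()).to(device)
parallel = nn.DataParallel(wrapped)     # on a CPU-only machine this just calls wrapped
inputs, target = torch.randn(16, 8), torch.randn(16, 4)
loss = parallel(inputs, target).mean()  # average the per-GPU losses
loss.backward()
```

Returning loss.unsqueeze(0) instead of a 0-dim tensor lets DataParallel concatenate the per-GPU losses into a vector; with 0-dim outputs it warns and unsqueezes them itself.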

So instead, I put both the generator and the discriminator into one model. The only complication is that I want to update the discriminator before calling the generator, so I have to issue two calls to the joint model. The first call goes fine, and I use it to store the generator inputs, the generator outputs and the ground-truth data used for supervision as attributes of the model. The problem is that when I issue the second call, the attributes I added during the first call are no longer there, as if the model had been re-instantiated. The error I get is:

AttributeError: 'TrainingIteration' object has no attribute 'g_out'

BTW, looking at nvidia-smi it seems this approach indeed resolves the uneven memory issue.

I created a simplified, illustrative version of my code (it won’t run as-is, so please excuse any bugs introduced by the simplification) and would appreciate any help. I would like to understand the mechanism that “steals” my variables away, but if anyone has another solution to what I am trying to do, that’s welcome as well. Thanks in advance for any help with this.

Here is the code:

import torch.nn as nn
import torch.optim as optim

class TrainingIteration(nn.Module):
    def __init__(self, config, g_model, d_model):
        super(TrainingIteration, self).__init__()
        self.config = config
        self.g_model = g_model
        self.d_model = d_model

    def forward(self, *args):
        nargin = len(args)
        if nargin > 0: # discriminator training
            return self.forward_discr(*args)
        else: # generator training
            return self.forward_gen()

    def forward_discr(self, inputs, target):
        # store the batch so the generator step can reuse it
        self.inputs = inputs
        self.target = target

        self.g_out = self.g_model(self.inputs)
        # discr_lossfunc calls d_model twice (discr_lossfunc / gen_lossfunc omitted here)
        d_loss = self.discr_lossfunc(self.g_out.detach(), self.target)
        return d_loss

    def forward_gen(self):
        D_fake = self.d_model(self.g_out)
        g_loss = self.gen_lossfunc(self.g_out, self.target, D_fake)
        return g_loss

class ModelTrainer:
    def __init__(self, config, g_model, d_model):
        self.config = config
        self.g_model = g_model
        self.d_model = d_model

        self.num_iter = 0
        # initialize TrainIteration
        self.train_iter = TrainingIteration(self.config, self.g_model, self.d_model)

        self.train_iter = nn.DataParallel(self.train_iter, device_ids=self.config.cuda_device_num)
        # Other initializations ....
        # not sure this will work, I might have to use self.train_iter.g_model
        self.g_optimizer = optim.Adam(self.g_model.parameters(), self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.d_optimizer = optim.Adam(self.d_model.parameters(), self.config.d_lr, [self.config.beta1, self.config.beta2])

        self.train_dataloader = ...

    def train(self):
        # ============================================
        # Training loop
        # ============================================
        # train the network
        for epoch in range(self.config.num_epochs):

            for batch_i, sample_batched in enumerate(self.train_dataloader):

                self.num_iter += 1 # counting the overall iteration number
                # ============================================
                # parse the inputs
                # ============================================
                inputs, target = sample_batched
                # First call
                d_loss = self.train_iter(inputs, target)
                d_loss = d_loss.mean()
                # perform parameter update

                # ============================================
                #            TRAIN THE GENERATOR
                # ============================================
                # clear gradients.
                # Second call - fails
                g_loss = self.train_iter()
                g_loss = g_loss.mean()
                # perform backprop
                # perform parameter update
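One way to sidestep the lost attributes is to keep no state between the two calls at all: pass the same batch to both calls and let each replica recompute g_out. The sketch below assumes this trade-off is acceptable; TrainingIterationStateless, the train_gen flag and the BCE-based GAN loss are hypothetical placeholders for the real discr_lossfunc/gen_lossfunc, and the tiny nn.Linear models only make it runnable.

```python
import torch
import torch.nn as nn


class TrainingIterationStateless(nn.Module):
    """Hypothetical variant of TrainingIteration that stores nothing between calls."""

    def __init__(self, g_model: nn.Module, d_model: nn.Module):
        super().__init__()
        self.g_model = g_model
        self.d_model = d_model
        # Placeholder GAN loss; the real discr_lossfunc/gen_lossfunc are not shown in the post.
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, inputs, target, train_gen=False):
        if not train_gen:
            # Discriminator step: no graph through the generator is needed.
            with torch.no_grad():
                g_out = self.g_model(inputs)
            d_real = self.d_model(target)
            d_fake = self.d_model(g_out)
            loss = (self.bce(d_real, torch.ones_like(d_real))
                    + self.bce(d_fake, torch.zeros_like(d_fake)))
        else:
            # Generator step: recompute g_out on this replica so gradients reach g_model.
            g_out = self.g_model(inputs)
            d_fake = self.d_model(g_out)
            loss = self.bce(d_fake, torch.ones_like(d_fake))
        return loss.unsqueeze(0)  # shape (1,) per replica


device = "cuda" if torch.cuda.is_available() else "cpu"
g_model, d_model = nn.Linear(8, 8), nn.Linear(8, 1)  # toy stand-ins
step = nn.DataParallel(TrainingIterationStateless(g_model, d_model).to(device))
g_optimizer = torch.optim.Adam(g_model.parameters(), lr=2e-4)
d_optimizer = torch.optim.Adam(d_model.parameters(), lr=2e-4)
inputs, target = torch.randn(16, 8), torch.randn(16, 8)

# First call: update the discriminator.
d_optimizer.zero_grad()
d_loss = step(inputs, target).mean()
d_loss.backward()
d_optimizer.step()

# Second call: same batch again, so nothing has to survive from the first call.
g_optimizer.zero_grad()
g_loss = step(inputs, target, train_gen=True).mean()
g_loss.backward()
g_optimizer.step()
```

The cost is one extra generator forward pass per iteration, and if the generator samples noise or uses dropout, the g_out seen in the second call will differ from the first. Computing both losses in a single call avoids the extra pass, but then the generator loss is evaluated against the discriminator before its update.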

@smth, @albanD, ptrblck: I read your posts a lot, and it seems there isn’t anything about PyTorch you don’t know. Perhaps you can take a look?
I would think this should interest a lot of people in the community, as I expect anyone training GANs to run into similar problems.

@ptrblck: I couldn’t mention you in the previous post due to new-user limitations.
Thanks again.

self.g_out won’t be gathered: DataParallel creates replicas of your original model in every forward call, and self.g_out is newly created on such a replica. Only the returned outputs are gathered.
You could try to initialize it in __init__ with an empty tensor and assign the output of your generator during the forward pass.
Let me know if that works for you.

Thanks for the suggestion. I tried to implement this, but unfortunately it doesn’t seem to work.
I initialized self.inputs and self.g_out to either torch.empty() or torch.zeros(), trying both a size of 1 (arbitrarily) and the true size of those tensors (using as batch size the full batch size divided by the number of GPUs). However, in the second call, g_loss = self.train_iter(), those tensors are reset, as if the constructor had been called again (it hasn’t; I added a printout in TrainingIteration.__init__).
To clarify: at the end of forward_discr() I print a value from self.g_out, and I print the same value at the beginning of forward_gen(), expecting (or hoping) it would be the same. Instead it is either zero or a random number, depending on whether I initialize self.g_out with torch.zeros() or torch.empty().
So, bottom line, no success so far…
If there are any additional suggestions as to what can be done, I’d appreciate them. Also, if there is any in-depth material on how DataParallel works under the hood, that may shed some light on the issue.
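The behaviour described above can be reproduced with a toy, pure-Python mimic of what happens per call: DataParallel builds fresh replicas in each forward call and discards them afterwards. This is a deliberately rough simplification, not PyTorch's actual replicate code; ToyModule and toy_parallel_call are hypothetical, and a list stands in for a tensor.

```python
import copy


class ToyModule:
    """Stand-in for TrainingIteration; g_out plays the role of a tensor."""

    def __init__(self):
        self.g_out = [0.0]          # "initialized in __init__", as suggested above

    def forward_rebind(self, value):
        self.g_out = [value]        # rebinds the attribute on this object only

    def forward_inplace(self, value):
        self.g_out[0] = value       # mutates the object the original also references


def toy_parallel_call(module, method_name, value):
    """Very rough mimic of one DataParallel call: replicate, run, discard."""
    replica = copy.copy(module)     # new object whose attributes point to the same objects
    getattr(replica, method_name)(value)
    # The replica is dropped when this function returns.


original = ToyModule()
toy_parallel_call(original, "forward_rebind", 1.5)
print(original.g_out)  # [0.0]
toy_parallel_call(original, "forward_inplace", 2.5)
print(original.g_out)  # [2.5]
```

self.g_out = self.g_model(...) inside forward() is the rebinding case, which is why the value set in __init__ reappears in the second call. The real mechanism is more involved: parameters and buffers are broadcast to each GPU, and according to the DataParallel documentation only the replica on device[0] shares their storage with the original module.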


Was anyone able to figure out an answer to this issue? I’m experiencing the same problem (for a very different reason). I am trying to save a variable during the forward pass of the network and access it later. This works fine on one GPU but causes issues on multiple GPUs.
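For saving a value computed during the forward pass, one option is to return it as an extra output instead of storing it on self, so DataParallel gathers it like any other output. NetWithAux and its layer sizes are made up for illustration.

```python
import torch
import torch.nn as nn


class NetWithAux(nn.Module):
    """Hypothetical module that returns an intermediate tensor instead of storing it."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x: torch.Tensor):
        features = torch.relu(self.encoder(x))  # the value one would otherwise store on self
        logits = self.head(features)
        return logits, features                 # DataParallel gathers each tuple element


device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.DataParallel(NetWithAux().to(device))
logits, features = net(torch.randn(32, 8))
print(logits.shape, features.shape)  # torch.Size([32, 2]) torch.Size([32, 16])
```

The gathered tensors end up on the output device (GPU 0 by default), so this brings back some of the memory imbalance discussed at the top of the thread; it suits small tensors best.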

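I’m not at my computer, so loosely described: you need to use register_buffer() to create this variable. Registered buffers, like nn.Parameters, are copied to every GPU replica on each call; plain attributes assigned inside forward() only exist on the replica and are thrown away. Note that the DataParallel docs only guarantee that in-place updates (e.g. with copy_()) made on the device[0] replica are recorded in the original module.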
Hope this helps.
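A small sketch of the register_buffer() approach, assuming the state can be updated in place. CallCounter is a hypothetical module. Per the DataParallel documentation, buffers are copied to every replica on each call, but only in-place updates on the device[0] replica are kept in the original module, so a rebinding assignment such as self.g_out = ... in forward() is still lost. As far as I understand the documented behaviour, for a per-GPU tensor like the generator output in the original question at most device[0]'s slice would survive and then be copied to every GPU on the next call, so this pattern fits shared state (counters, running statistics) better than per-batch intermediate results.

```python
import torch
import torch.nn as nn


class CallCounter(nn.Module):
    """Hypothetical example of state DataParallel can keep: a buffer updated in place."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Registered buffers move with .to()/.cuda() and are broadcast to every replica.
        self.register_buffer("num_calls", torch.zeros((), dtype=torch.long))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # In-place update: under DataParallel only the device[0] replica shares storage
        # with the original module, so this is the update that persists.
        self.num_calls.add_(1)
        # By contrast, `self.num_calls = self.num_calls + 1` would rebind the buffer on
        # the replica only, and the change would be lost after the call.
        return self.linear(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
counter = CallCounter().to(device)
model = nn.DataParallel(counter)
for _ in range(3):
    model(torch.randn(8, 4))
print(counter.num_calls.item())  # 3
```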