torch.nn.DataParallel behaves strangely. This might be a bug, or I may be misunderstanding how it works

My real code is much more complex, so I tried to recreate the issue using the official PyTorch example for nn.DataParallel. I want a forward function that takes one of two routes depending on the value of a route variable. If route is True, it does a normal forward pass through a Linear layer, and the output is stored in self.out, an attribute of the model. If route is False, I just print the shape of self.out. Surprisingly, the model thinks self.out is still None, which is weird. So either I do not understand how DataParallel works or this is a serious issue. I would appreciate any insight on this.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

input_size = 5
output_size = 2

batch_size = 30
data_size = 100

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)

class Model(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.out = None

    def forward(self, x):
        input = x[0]
        route = x[1]

        if route:
            output = self.fc(input)
            if self.out is None:
                self.out = output
            else:
                # torch.cat returns a new tensor, so assign it; concatenate along the batch dim
                self.out = torch.cat((self.out, output), 0)
            return output
        else:
            print(self.out.shape)
            return self.out.shape
        
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)
model.to(device)

for data in rand_loader:
    input = data.to(device)
    output = model((input, True))
    print("Outside: input size", input.size(),
          "output_size", output.size())
    model((input, False))

Here is the error:

AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_3007/2423608035.py", line 49, in forward
    print(self.out.shape)
AttributeError: 'NoneType' object has no attribute 'shape'

I'm not quite sure why this happens. Can you try passing the two arguments separately, e.g. def forward(self, inputx, route)?

nn.DataParallel handles input arguments differently depending on their type. Wrapping a tensor inside a tuple might cause some unexpected errors.
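The DataParallel docs describe the rule: tensors are scattered along the given dim (default 0), tuples, lists and dicts are shallow copied, and other objects are shared across replicas. A toy mirror of that rule (toy_scatter is a hypothetical helper, not PyTorch's implementation, and it ignores dicts and device placement) suggests the tuple by itself is probably not the culprit, since each replica would still receive (chunk, True):

```python
import torch

def toy_scatter(obj, n_chunks, dim=0):
    """Hypothetical simplification of how DataParallel splits forward inputs."""
    if isinstance(obj, torch.Tensor):
        # Tensors are split into chunks along `dim`, one chunk per replica.
        return list(obj.chunk(n_chunks, dim=dim))
    if isinstance(obj, (tuple, list)) and len(obj) > 0:
        # Containers are rebuilt per replica from their scattered elements.
        per_element = [toy_scatter(o, n_chunks, dim) for o in obj]
        return [type(obj)(parts) for parts in zip(*per_element)]
    # Anything else (like the bool `route`) is handed unchanged to every replica.
    return [obj for _ in range(n_chunks)]

pieces = toy_scatter((torch.randn(30, 5), True), n_chunks=2)
for chunk, route in pieces:
    print(chunk.shape, route)
```

Here each of the two replicas would get a 15-row slice of the batch together with route=True.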

That still gives me the same error.

So the problem might be that you're modifying a module attribute inside the forward function. See the second warning note here:
https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#dataparallel

You update self.out in the forward function, but DataParallel replicates the module on every forward call, so that assignment happens on a replica and is lost once the call returns.
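A small CPU-only simulation can make that warning concrete. This is a sketch under assumptions: copy.deepcopy stands in for DataParallel's per-call replication (the real replication goes through torch.nn.parallel.replicate and does not deep-copy parameters), and Accumulator and fake_data_parallel_forward are hypothetical names.

```python
import copy
import torch
import torch.nn as nn

class Accumulator(nn.Module):
    """Hypothetical module that caches its output on self, like Model.out."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 2)
        self.out = None

    def forward(self, x):
        output = self.fc(x)
        self.out = output  # rebinds the attribute only on the object running forward
        return output

def fake_data_parallel_forward(module, x, n_replicas=2):
    """Rough stand-in for DataParallel: fresh copies per call, one chunk each."""
    replicas = [copy.deepcopy(module) for _ in range(n_replicas)]
    chunks = x.chunk(n_replicas, dim=0)
    outputs = [replica(chunk) for replica, chunk in zip(replicas, chunks)]
    return torch.cat(outputs, dim=0)  # the replicas are discarded after this

model = Accumulator()
y = fake_data_parallel_forward(model, torch.randn(30, 5))
print(y.shape)
print(model.out is None)  # the original module never saw the assignment
```

The gathered output has the full batch, yet model.out on the original module is still None, which matches the error in the question.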

That makes sense. Is there any workaround for this?

Adding one more argument, last_output, to the forward function could do it.
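Below is a minimal sketch of a closely related variant, assuming the goal is just to accumulate outputs across batches. A tensor passed as last_output would itself be scattered along dim 0 like any other tensor argument, so this sketch keeps forward stateless and lets the caller own last_output instead. TensorDataset replaces the question's RandomDataset only to keep the example short.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        # Stateless forward: nothing is assigned to self, so replication is harmless.
        return self.fc(x)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model(5, 2)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)

loader = DataLoader(TensorDataset(torch.randn(100, 5)), batch_size=30, shuffle=True)

last_output = None  # running state kept by the caller, outside any replica
with torch.no_grad():  # avoid holding autograd graphs for every batch
    for (data,) in loader:
        output = model(data.to(device))
        if last_output is None:
            last_output = output
        else:
            last_output = torch.cat((last_output, output), 0)

print(last_output.shape)
```

The route=False branch then becomes a plain print(last_output.shape) in the caller, which sees all 100 rows whether or not DataParallel is used.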

Yeah, I thought about the same thing. I wish it were easier than that, though. Thanks.