Training a network with multiple outputs on multiple GPUs

Hi,

I have a model with multiple outputs (20 of them), and each output is a linear classifier over 4-10 classes. One way to make the model return these 20 outputs is:

# for example...
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        ...
        self.features = nn.Sequential(...)
        # repeat the code N times
        self.classifier_1 = nn.Linear(1000, 7)
        self.classifier_2 = nn.Linear(1000, 4)
        ...
        self.classifier_N = nn.Linear(1000, 10)

    def forward(self, x):
        x = self.features(x)
        # repeat the code N times
        y1 = self.classifier_1(x)
        y2 = self.classifier_2(x)
        ...
        yN = self.classifier_N(x)
        return y1, y2, ..., yN

This approach works on a single GPU as well as on multiple GPUs.

However, repeating the classifier code (self.classifier_X) N times is neither elegant nor flexible. Here is a more dynamic solution:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        ...
        self.features = nn.Sequential(...)
        self.nClassesPerClassifier = [7, 4, ..., 10]
        # create N classifiers; setattr on an nn.Module registers each one
        for c in range(len(self.nClassesPerClassifier)):
            setattr(self, 'class_%d' % c,
                    nn.Linear(1000, self.nClassesPerClassifier[c]))

    def forward(self, x):
        x = self.features(x)
        # return a dictionary whose key 'class_<c>' identifies each classifier
        return {'class_%d' % c: getattr(self, 'class_%d' % c)(x)
                for c in range(len(self.nClassesPerClassifier))}

This approach works on a single GPU, but with multiple GPUs I get this error:

*** RuntimeError: maximum recursion depth exceeded while calling a Python object

The same model works if forward() returns a list or tuple of N Variables instead of a dictionary.
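The list/tuple-returning variant can stay dynamic if the heads are stored in an `nn.ModuleList` (a swap-in for the `setattr` loop, not the poster's code). This is a minimal sketch, assuming a toy `nn.Linear(32, 1000)` feature extractor and a shortened class list `[7, 4, 10]`; `MultiHeadModel` is a hypothetical name. When no GPU is available, `nn.DataParallel` simply calls the wrapped module, so the sketch also runs on CPU.

```python
import torch
import torch.nn as nn


class MultiHeadModel(nn.Module):
    """Hypothetical model: shared features plus one linear head per task."""

    def __init__(self, n_classes_per_classifier=(7, 4, 10)):
        super(MultiHeadModel, self).__init__()
        # Assumed toy feature extractor; the post elides the real one.
        self.features = nn.Sequential(nn.Linear(32, 1000), nn.ReLU())
        # ModuleList registers every head, so .cuda(), .parameters()
        # and state_dict() all see them.
        self.classifiers = nn.ModuleList(
            [nn.Linear(1000, n) for n in n_classes_per_classifier])

    def forward(self, x):
        x = self.features(x)
        # A tuple of tensors is something DataParallel can gather.
        return tuple(head(x) for head in self.classifiers)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(MultiHeadModel()).to(device)
outputs = model(torch.randn(8, 32, device=device))
print([tuple(o.shape) for o in outputs])  # [(8, 7), (8, 4), (8, 10)]
```

The batch dimension is split across GPUs and concatenated back, so each head's output still has 8 rows.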

Does anybody know another way to handle multiple outputs with multiple GPUs?
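A plausible reading of the RecursionError: the gather() of that era special-cased only Variables and None, and rebuilt anything else with `type(out)(map(gather_map, zip(*outputs)))`. Iterating a dict yields its keys, which are strings, and every character of a string is again a string, so the recursion never bottoms out. The pure-Python mock below reproduces that shape under stated assumptions: `FakeTensor` is a hypothetical stand-in for a Variable, list concatenation stands in for `Gather`, and `list(map(...))` mimics Python 2's eager `map`.

```python
class FakeTensor(object):
    """Hypothetical stand-in for a Variable: just a list of batch rows."""

    def __init__(self, rows):
        self.rows = rows


def gather_map(outputs):
    """Mock of the old gather_map, with concatenation in place of Gather."""
    out = outputs[0]
    if isinstance(out, FakeTensor):
        # "Gather": concatenate the per-GPU chunks along the batch dimension.
        return FakeTensor([row for o in outputs for row in o.rows])
    if out is None:
        return None
    # Fallback for anything else: treat it as a sequence and recurse.
    # list(...) forces evaluation, as Python 2's map did.
    return type(out)(list(map(gather_map, zip(*outputs))))


# Two "GPUs" each return a tuple: gathering works.
per_gpu = [(FakeTensor([1, 2]),), (FakeTensor([3, 4]),)]
print(gather_map(per_gpu)[0].rows)  # [1, 2, 3, 4]

# Two "GPUs" each return a dict: zip() walks the keys, then the
# characters of each key, and never reaches a FakeTensor.
per_gpu = [{"class_0": FakeTensor([1, 2])}, {"class_0": FakeTensor([3, 4])}]
try:
    gather_map(per_gpu)
except RecursionError as exc:
    print("RecursionError:", exc)
```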

I solved it as follows:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        ...
        self.features = nn.Sequential(...)

        self.nClassesPerClassifier = [7, 4, ..., 10]
        # create N classifiers
        for c in range(len(self.nClassesPerClassifier)):
            setattr(self, 'class_%d' % c,
                    nn.Linear(1000, self.nClassesPerClassifier[c]))

    def forward(self, x):
        x = self.features(x)  # this part runs on multiple GPUs
        # the dictionary is built after DataParallel, on a single device
        return {'class_%d' % c: getattr(self, 'class_%d' % c)(x)
                for c in range(len(self.nClassesPerClassifier))}

    def set_multiple_gpus(self):
        # wrap only the feature extractor in DataParallel
        self.features = nn.DataParallel(self.features)
        # move everything, including the classifier heads, to the GPU
        self.cuda()

model = Model()
model.load_state_dict(...)  # load pretrained weights BEFORE wrapping
model.set_multiple_gpus()
...
Y = model(inputs)

The problem was that forward() returned a dictionary {key: Variable}. A single GPU or the CPU handles this fine, but torch.nn.DataParallel (multiple GPUs) cannot gather a dictionary; it can only gather a list or tuple of outputs. However, PyTorch lets you choose which parts (layers) of the network run on the CPU, a single GPU, or multiple GPUs. So I specified that only the feature-extraction part runs on multiple GPUs; see the set_multiple_gpus() method. Also, if you load a pretrained model, you have to load it before calling set_multiple_gpus(), because wrapping self.features in DataParallel renames its parameters (features.* becomes features.module.*).
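A CPU-runnable sketch of this approach, assuming toy sizes (`nn.Linear(32, 1000)` as the feature extractor and three heads instead of twenty); `SplitModel`, `feature_keys` and the `pretrained` stand-in are hypothetical names. It also shows why loading must happen before wrapping: the state-dict keys of the wrapped part gain a `module.` segment.

```python
import torch
import torch.nn as nn


class SplitModel(nn.Module):
    """Hypothetical model: shared features plus one linear head per task."""

    def __init__(self, n_classes_per_classifier=(7, 4, 10)):
        super(SplitModel, self).__init__()
        # Assumed toy feature extractor standing in for nn.Sequential(...).
        self.features = nn.Sequential(nn.Linear(32, 1000), nn.ReLU())
        self.n_classes_per_classifier = list(n_classes_per_classifier)
        for c, n in enumerate(self.n_classes_per_classifier):
            setattr(self, "class_%d" % c, nn.Linear(1000, n))

    def forward(self, x):
        x = self.features(x)  # the only data-parallel part
        # The dict is built after DataParallel has gathered the features.
        return {"class_%d" % c: getattr(self, "class_%d" % c)(x)
                for c in range(len(self.n_classes_per_classifier))}

    def set_multiple_gpus(self):
        self.features = nn.DataParallel(self.features)
        if torch.cuda.is_available():
            self.cuda()  # move the heads as well, not only the features


def feature_keys(module):
    return [k for k in module.state_dict() if k.startswith("features")]


pretrained = SplitModel().state_dict()  # stand-in for a real checkpoint
model = SplitModel()
print(feature_keys(model))  # ['features.0.weight', 'features.0.bias']
model.load_state_dict(pretrained)  # load BEFORE wrapping
model.set_multiple_gpus()
print(feature_keys(model))  # ['features.module.0.weight', 'features.module.0.bias']

device = next(model.parameters()).device
Y = model(torch.randn(8, 32, device=device))
print({k: tuple(v.shape) for k, v in Y.items()})
# {'class_0': (8, 7), 'class_1': (8, 4), 'class_2': (8, 10)}
```

Calling `model.load_state_dict(pretrained)` after `set_multiple_gpus()` would instead raise a RuntimeError about missing and unexpected keys, since the names no longer match.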

Cheers

It'd be better if you explained how you solved it, instead of making people read through the code to find the change…

I have exactly the same situation; thanks for the solution. Since I want as much of the network as possible to run on multiple GPUs, I ended up changing the gather() function in torch/nn/parallel/scatter_gather.py to:

# extra import at the top of scatter_gather.py
# (Variable and Gather are already imported in that file)
from collections import OrderedDict

def gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
      (-1 means the CPU).
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, Variable):
            return Gather(target_device, dim=dim)(*outputs)
        if isinstance(out, OrderedDict):
            # new branch: gather each value separately and keep the keys
            return OrderedDict([
                (k, Gather(target_device, dim=dim)(*(each[k] for each in outputs)))
                for k in out.keys()])
        if out is None:
            return None
        return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs)

Then, in the network's forward(), return an OrderedDict, e.g.:

# requires: from collections import OrderedDict
def forward(self, x):
    x = self.features(x)
    x_emb = self.emb(x)  # normal embedding head
    # below are small classification heads for all types of attributes
    x_attr_dict = OrderedDict([(attr_name, mod(x))
                               for attr_name, mod in self.attr_classifiers.named_children()])
    return x, x_attr_dict

This way, the whole network can be wrapped in DataParallel, and x_attr_dict is still gathered properly even though it is an OrderedDict whose keys are attribute names and whose values are the outputs for each attribute.
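The effect of the extra OrderedDict branch can be checked without GPUs using a mock under stated assumptions: `FakeTensor` is a hypothetical stand-in for a Variable, `concat` stands in for `Gather`, and the attribute names `color` and `shape` are made up. Each key is gathered across devices separately and the key order is preserved. (Recent PyTorch releases include a similar dict branch in `gather`, so a patch like this may only be needed on old versions.)

```python
from collections import OrderedDict


class FakeTensor(object):
    """Hypothetical stand-in for a Variable: a list of batch rows."""

    def __init__(self, rows):
        self.rows = rows


def concat(*chunks):
    """Stand-in for Gather: concatenate per-device chunks along dim 0."""
    return FakeTensor([row for chunk in chunks for row in chunk.rows])


def gather(outputs):
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, FakeTensor):
            return concat(*outputs)
        if isinstance(out, OrderedDict):
            # Gather each key across devices; keep the original key order.
            return OrderedDict(
                [(k, concat(*[each[k] for each in outputs])) for k in out.keys()])
        if out is None:
            return None
        return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs)


# Each "GPU" returns (x, x_attr_dict), as in the forward() above.
gpu0 = (FakeTensor([0, 1]),
        OrderedDict([("color", FakeTensor(["a", "b"])), ("shape", FakeTensor(["p", "q"]))]))
gpu1 = (FakeTensor([2, 3]),
        OrderedDict([("color", FakeTensor(["c", "d"])), ("shape", FakeTensor(["r", "s"]))]))
x, attrs = gather([gpu0, gpu1])
print(x.rows)               # [0, 1, 2, 3]
print(list(attrs))          # ['color', 'shape']
print(attrs["color"].rows)  # ['a', 'b', 'c', 'd']
```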

Thanks for sharing! I was about to implement exactly this when I saw your reply :)
Have you considered making a PR for this?

Also, I had to change your code a bit to make it work for me:

if isinstance(out, OrderedDict):
    # Gather is invoked via Gather.apply here instead of Gather(...)(...)
    return OrderedDict([(k, Gather.apply(target_device, dim, *[each[k] for each in outputs]))
                        for k in out.keys()])

I haven't thought about submitting a PR. My modification is for PyTorch 0.2; from 0.3 on, your modification should be used.
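The difference between the two snippets comes from the autograd Function API: in the version the second snippet targets, `Gather` is a new-style Function whose `forward`/`backward` are static methods taking a `ctx` argument, and it is invoked through the class-level `.apply` instead of being instantiated with arguments and then called. A minimal sketch of that calling convention, using a hypothetical `ScaleBy` function unrelated to Gather:

```python
import torch
from torch.autograd import Function


class ScaleBy(Function):
    """Hypothetical new-style Function: multiplies a tensor by a constant."""

    @staticmethod
    def forward(ctx, factor, x):
        # Non-tensor arguments (like Gather's target_device and dim)
        # are passed positionally before the tensors.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; None for the non-tensor factor.
        return None, grad_output * ctx.factor


x = torch.ones(3, requires_grad=True)
y = ScaleBy.apply(2.0, x)  # called via .apply, never instantiated
y.sum().backward()
print(y.tolist(), x.grad.tolist())  # [2.0, 2.0, 2.0] [2.0, 2.0, 2.0]
```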

OK, do you want me to do the PR?

Sure, please do it :)