nn.DataParallel wraps the whole model?

I’m rather confused after reading both of the official tutorials on multi-GPU training and data parallelism. Can I wrap the whole model in nn.DataParallel, instead of one layer at a time?

For example, is the following code block legitimate?

import torch
import torch.nn as nn


class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

model = DataParallelModel()
model = nn.DataParallel(model)

Thanks!


Hi @yuqli! Yep, the code as you have it works perfectly. If you print out the model, you can see that the whole thing is wrapped in DataParallel:

>>> model
DataParallel(
  (module): DataParallelModel(
    (block1): Linear(in_features=10, out_features=20, bias=True)
    (block2): Linear(in_features=20, out_features=20, bias=True)
    (block3): Linear(in_features=20, out_features=20, bias=True)
  )
)
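
For completeness, here’s a minimal sketch of how you’d actually push data through the wrapped model (assuming at least one CUDA GPU is available; the batch size of 32 is just for illustration):

model = DataParallelModel()
model = nn.DataParallel(model)
model = model.to("cuda")                 # DataParallel expects the parameters to live on GPU

x = torch.randn(32, 10, device="cuda")   # a batch of 32 samples with 10 features each
out = model(x)                           # the batch is split across available GPUs and gathered back
print(out.shape)                         # torch.Size([32, 20])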

There’s some more info on DataParallel and how it works in this forum post, where @rasbt gives a good diagram.

Hope that answers your question!

(If you’re curious about the inner workings of DataParallel, @fantasticfears has a great post on their blog.)
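
Roughly speaking, DataParallel’s forward pass boils down to the functional primitives in torch.nn.parallel, something like this simplified sketch (the real implementation also handles kwargs, buffers, and various edge cases):

from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(module, inputs, device_ids, output_device):
    # split the batch along dim 0, one chunk per device
    scattered = scatter(inputs, device_ids)
    # copy the module (parameters and buffers) onto each device
    replicas = replicate(module, device_ids[:len(scattered)])
    # run each replica on its own chunk, in parallel
    outputs = parallel_apply(replicas, scattered)
    # concatenate the per-device outputs back on the output device
    return gather(outputs, output_device)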
