nn.DataParallel wraps the whole model?

Hi @yuqli! Yep, your code works as written. If you print out your model, you can see that the whole thing is wrapped in a single DataParallel wrapper, with your original model stored under its module attribute.

>>> model
DataParallel(
  (module): DataParallelModel(
    (block1): Linear(in_features=10, out_features=20, bias=True)
    (block2): Linear(in_features=20, out_features=20, bias=True)
    (block3): Linear(in_features=20, out_features=20, bias=True)
  )
)
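
In case it helps, here is a minimal sketch that should reproduce a printout like the one above. The class name and the three Linear blocks are taken from the printed output; the forward method is my assumption, since your original one isn't shown here, so swap in your own.

```python
import torch
import torch.nn as nn


class DataParallelModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Layer names and sizes match the printout above.
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        # Assumed forward pass: the blocks applied one after another.
        return self.block3(self.block2(self.block1(x)))


# Use the GPU(s) if present; on a CPU-only machine DataParallel
# simply calls the wrapped module directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single DataParallel wrapper around the whole model.
model = nn.DataParallel(DataParallelModel().to(device))
print(model)

# The original model is reachable through the .module attribute.
inner = model.module
print(type(inner).__name__)

# Parameter names in the wrapper's state_dict carry a "module." prefix.
first_key = next(iter(model.state_dict()))
print(first_key)

# The batch is split along dim 0 across the available GPUs.
x = torch.randn(8, 10, device=device)
out = model(x)
print(out.shape)
```

One thing to keep in mind: because everything sits under model.module, a checkpoint saved from the wrapped model has "module." in front of every key. Saving model.module.state_dict() instead avoids that.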

There’s some more info on DataParallel and how it works in this forum post, where @rasbt gives a good diagram.

Hope that answers your question!

(If you’re curious about how DataParallel works internally, @fantasticfears has a great post on their blog.)
