Hi @yuqli! Yep, running the code as you have it works perfectly. If you print out your model, you can see that the whole thing is wrapped in a DataParallel wrapper:
>>> model
DataParallel(
  (module): DataParallelModel(
    (block1): Linear(in_features=10, out_features=20, bias=True)
    (block2): Linear(in_features=20, out_features=20, bias=True)
    (block3): Linear(in_features=20, out_features=20, bias=True)
  )
)
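If anyone wants to reproduce this, here's a minimal sketch. The DataParallelModel class is a hypothetical reconstruction based only on the printout (three Linear layers applied in sequence), so the actual code in the question may differ. It also shows that the wrapped model is stored as the wrapper's module attribute, which is where the "(module):" line comes from.

```python
import torch
import torch.nn as nn

# Hypothetical reconstruction of the model from the printout:
# three Linear layers applied one after another.
class DataParallelModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

# When GPUs are present, DataParallel expects the parameters to already
# live on the first device, so move the model there before wrapping it.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(DataParallelModel().to(device))

# The printed structure shows the original model nested as "(module)".
print(model)

# The unwrapped model is still reachable through the .module attribute.
print(type(model.module).__name__)

# The wrapper is called exactly like the original model would be.
out = model(torch.randn(8, 10, device=device))
print(out.shape)
```

On a machine without GPUs, nn.DataParallel just calls the wrapped model directly, so the same snippet also runs on CPU.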
There’s some more info on DataParallel and how it works in this forum post, where @rasbt gives a good diagram.
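To give a rough feel for the idea, here's a sequential, CPU-only simulation of the replicate / scatter / apply / gather steps that DataParallel's forward pass is built around. This is a simplified illustration, not how PyTorch actually implements it: the real wrapper copies the model onto each GPU and runs the chunks concurrently, and gradients flow back to the original model, whereas here each "replica" is just a deep copy and only the forward pass is shown. The function name simulated_data_parallel is made up for this example.

```python
import copy
import torch
import torch.nn as nn

def simulated_data_parallel(module, x, num_replicas=2):
    """Sequential sketch of the data-parallel forward pass (forward only)."""
    # 1. Replicate: one copy of the model per (pretend) device.
    replicas = [copy.deepcopy(module) for _ in range(num_replicas)]
    # 2. Scatter: split the input batch along dimension 0.
    chunks = torch.chunk(x, num_replicas, dim=0)
    # 3. Apply: each replica processes its own chunk (one after another here).
    outputs = [replica(chunk) for replica, chunk in zip(replicas, chunks)]
    # 4. Gather: concatenate the partial results back into a single batch.
    return torch.cat(outputs, dim=0)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 20), nn.Linear(20, 20))
x = torch.randn(8, 10)

with torch.no_grad():
    parallel_out = simulated_data_parallel(model, x, num_replicas=2)
    direct_out = model(x)

# Splitting the batch and stitching it back together gives the same result
# as running the whole batch through one copy of the model.
print(torch.allclose(parallel_out, direct_out, atol=1e-6))  # True
```

Because every replica has identical weights and each sample is processed independently, splitting the batch doesn't change the result; the benefit in the real wrapper comes from running those chunks on different GPUs at the same time.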
Hope that answers your question!
(If you’re curious about how DataParallel works internally, @fantasticfears has a great post on their blog.)