Have a look at this small example I’ve created, which combines model sharding with nn.DataParallel.
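The linked example is not reproduced here. As a rough sketch of the general idea (not the original code), the model can be split into two shards placed on different groups of GPUs (model sharding), with each shard wrapped in nn.DataParallel so it is replicated across the GPUs of its own group. The helper names `place` and `pick_device_groups`, the class `ShardedModel`, the layer sizes, and the GPU split are all hypothetical. The code falls back to plain CPU execution when fewer than two GPUs are available.

```python
import torch
import torch.nn as nn


def place(module, device_ids):
    """Hypothetical helper: move `module` to the first device of its group and
    replicate it across the group with nn.DataParallel. Without a GPU group,
    the module is returned unwrapped on the CPU."""
    if device_ids:
        device = torch.device(f"cuda:{device_ids[0]}")
        # nn.DataParallel expects the parameters on device_ids[0].
        return nn.DataParallel(module.to(device), device_ids=device_ids), device
    return module, torch.device("cpu")


def pick_device_groups():
    """Hypothetical split: two GPUs per shard with 4+ GPUs, one GPU per shard with 2-3."""
    n = torch.cuda.device_count()
    if n >= 4:
        return [0, 1], [2, 3]
    if n >= 2:
        return [0], [1]
    return [], []  # CPU-only fallback


class ShardedModel(nn.Module):
    def __init__(self):
        super().__init__()
        group1, group2 = pick_device_groups()
        # Shard 1 lives on group1 and shard 2 on group2 (model sharding);
        # each shard is data-parallel within its own group.
        self.stage1, self.dev1 = place(nn.Sequential(nn.Linear(10, 20), nn.ReLU()), group1)
        self.stage2, self.dev2 = place(nn.Linear(20, 5), group2)

    def forward(self, x):
        # When wrapped, DataParallel scatters the batch across the group and
        # gathers the output back on the group's first device.
        x = self.stage1(x.to(self.dev1))
        # Move the activations to the device holding the second shard.
        return self.stage2(x.to(self.dev2))


if __name__ == "__main__":
    model = ShardedModel()
    out = model(torch.randn(8, 10))
    out.sum().backward()  # gradients flow back through both shards
    print(out.shape)  # torch.Size([8, 5])
```

Because each DataParallel wrapper gathers its output on the first device of its group, the only explicit transfer needed in `forward` is the `.to(self.dev2)` between the two shards.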