Can't figure out how to integrate DataParallel into this workflow

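If a single sample already causes an OOM error on one device, nn.DataParallel won’t help.
As the name suggests, it takes a data parallel approach: the model is replicated on each specified device, the input batch is split along dim0, and each device receives one chunk. With a batch of one sample there is nothing to split, so the full sample (and the full model) still has to fit on a single device.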
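
To make the dim0 split concrete, here is a rough sketch that mimics the scatter step with `torch.chunk`. The helper `split_batch` and the tensor shapes are made up for illustration; nn.DataParallel does the scattering internally, so this only approximates what each device would receive.

```python
import torch

def split_batch(batch, num_devices):
    # Hypothetical helper: approximates the scatter step of nn.DataParallel
    # by chunking the input along dim 0 (the batch dimension).
    return torch.chunk(batch, num_devices, dim=0)

# A batch of 8 samples spread over 4 (imaginary) devices: 2 samples each.
batch = torch.randn(8, 3, 32, 32)
chunks = split_batch(batch, 4)
chunk_shapes = [tuple(c.shape) for c in chunks]

# A single sample cannot be split any further: only one chunk comes back,
# so one device would still have to hold the whole sample.
single = torch.randn(1, 3, 32, 32)
single_chunks = split_batch(single, 4)

print(chunk_shapes)
print(len(single_chunks))
```

So with a batch size of 1, the extra devices would just sit idle and the memory footprint on the first device stays the same.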
If you want to use model sharding (model parallel), you could take a look at this post.
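
For reference, a minimal sketch of what manual model sharding could look like, assuming a toy two-stage model (`TwoStageModel`, the layer sizes, and the device choice are all placeholders, not taken from your code). Each stage lives on its own device and the activation is moved between them in `forward`. It falls back to the CPU if fewer than two GPUs are visible, so it runs anywhere.

```python
import torch
import torch.nn as nn

class TwoStageModel(nn.Module):
    # Hypothetical toy model: the first half lives on dev0, the second on dev1.
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.stage0 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()).to(dev0)
        self.stage1 = nn.Linear(32, 4).to(dev1)

    def forward(self, x):
        # Run the first stage on dev0, then move the activation to dev1.
        x = self.stage0(x.to(self.dev0))
        return self.stage1(x.to(self.dev1))

# Use two GPUs if available; otherwise keep both stages on the CPU.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")

model = TwoStageModel(dev0, dev1)
out = model(torch.randn(1, 16))  # a single sample, as in the OOM case
print(out.shape, out.device)
```

This way each device only has to hold its own part of the parameters and activations, which is what helps when a single sample doesn't fit. The output ends up on the second device, so the target would need to be moved there before computing the loss.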