I guess the default device (GPU0) might run out of memory, as your workflow seems to be close to the nn.DataParallel usage described here.
Could you try to use the recommended use case of one device/replica per DDP process as described here?
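Here is a minimal sketch of the one-device-per-process setup. It assumes the script is launched with `torchrun`, which sets `LOCAL_RANK`, `RANK` and `WORLD_SIZE`. The `nn.Linear` toy model, the random data and the environment-variable defaults are hypothetical placeholders. The default values only exist so the sketch can also run as a single CPU process. The key points are calling `torch.cuda.set_device(local_rank)` so that allocations don't pile up on GPU0, and passing exactly one device in `device_ids`.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def pick_device(local_rank: int) -> torch.device:
    # One process drives exactly one GPU: process `local_rank` uses cuda:<local_rank>.
    # Falls back to CPU so the sketch can also run on a machine without GPUs.
    if torch.cuda.is_available():
        return torch.device("cuda", local_rank)
    return torch.device("cpu")


def main() -> float:
    # torchrun sets these variables; the defaults (an assumption of this sketch)
    # allow a single-process run without torchrun.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    device = pick_device(local_rank)
    if device.type == "cuda":
        # Make cuda:<local_rank> the current device so that allocations and the
        # NCCL communicator are not created on GPU0 by every process.
        torch.cuda.set_device(device)
    backend = "nccl" if device.type == "cuda" else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    # Hypothetical toy model: one full replica lives on this process's device.
    model = nn.Linear(10, 1).to(device)
    # device_ids lists only this process's single GPU (must be None on CPU).
    ddp_model = DDP(model, device_ids=[local_rank] if device.type == "cuda" else None)

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    inputs = torch.randn(8, 10, device=device)
    targets = torch.randn(8, 1, device=device)

    optimizer.zero_grad()
    loss = nn.functional.mse_loss(ddp_model(inputs), targets)
    loss.backward()  # DDP all-reduces the gradients across processes here
    optimizer.step()

    dist.destroy_process_group()
    return loss.item()


if __name__ == "__main__":
    print(f"loss: {main():.4f}")
```

With e.g. two GPUs this would be started as `torchrun --nproc_per_node=2 train_ddp.py` (the file name is hypothetical). Each spawned process then owns a single GPU and a single model replica, instead of one process spreading work over several devices as nn.DataParallel does.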