Get batch's datapoints across all GPUs


I'm running my model (using PyTorch Lightning) on a cluster with multiple GPUs (2). My problem is that I would like to access all the datapoints in the batch. Because I'm using 2 GPUs, my batch is divided between those two devices for parallelisation purposes, which means that when I access the data in the batch during eval/training, I'm getting just half of the batch.

How can I obtain the complete batch and the model's predictions when they are divided among different devices/GPUs?


Are you using nn.DataParallel or DistributedDataParallel?
In the former (not recommended) case, the batch is split while being passed to the DataParallel model, so inside the forward of the model you would only see a chunk of the data. The output, however, is gathered again on the default device, so there you will see the output for the entire batch.
In the latter case, each process uses a single GPU and is responsible for loading the entire batch itself.
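If you do end up on DDP and need the full batch (and predictions) visible on every rank, `torch.distributed.all_gather` can collect the per-rank shards. A minimal sketch — the function name `gather_full_batch` is mine, and it assumes every rank holds tensors of identical shape (a ragged last batch would need padding or `all_gather_object`):

```python
import torch
import torch.distributed as dist

def gather_full_batch(local_inputs, local_preds):
    """Collect every rank's shard of the inputs and predictions.

    Assumes all ranks hold tensors of the same shape.
    """
    world_size = dist.get_world_size()
    # Pre-allocate one receive buffer per rank.
    inp_buf = [torch.zeros_like(local_inputs) for _ in range(world_size)]
    pred_buf = [torch.zeros_like(local_preds) for _ in range(world_size)]
    dist.all_gather(inp_buf, local_inputs)
    dist.all_gather(pred_buf, local_preds)
    # Concatenate along the batch dimension to recover the full batch.
    return torch.cat(inp_buf, dim=0), torch.cat(pred_buf, dim=0)
```

Note that `all_gather` does not propagate gradients through the gathered copies, so this is mainly useful for logging/metrics rather than for computing the loss on the full batch.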

Actually, I'm not quite sure what Lightning is using. I tried setting the flag accelerator=“ddp” on the Trainer, but the problem persists.

Could I force “ddp_model = DDP(model, device_ids=[rank])” on my model, or might that result in conflicts with Lightning?

I’m not familiar enough with Lightning so don’t know what it uses internally or how it would interact if you manually wrap the model into e.g. DDP.