Multi GPU segmentation using DataParallel - how to aggregate

I’m training a segmentation network on multiple GPUs and I’m using DataParallel.

My Dataset returns a dict of multiple tensors, e.g. {“image”: …, “segmentation”: …, “weight_map”: …, …}.
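For reference, a stripped-down sketch of what the Dataset looks like (the shapes and contents here are just placeholders):

```python
import torch
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):
    """Toy version of my Dataset: every sample is a dict of tensors."""
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return {
            "image": torch.randn(3, 256, 256),
            "segmentation": torch.randint(0, 2, (256, 256)),
            "weight_map": torch.rand(256, 256),
        }
```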

In my main training loop (get batch, send to the module, calculate the loss, backward(), print logs, …) I want to have access to all the information generated.

For example, when calculating the loss I use the “weight_map”, and periodically I want to save a random image from the current batch externally for illustration.

Option 1 is to send the entire dict to the Module, let it calculate the segmentation and the loss, and return just the final loss value.
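Roughly what I have in mind for Option 1 (just a sketch; SegModelWithLoss and the weighted cross-entropy are placeholders for my actual network and loss):

```python
import torch.nn as nn
import torch.nn.functional as F

class SegModelWithLoss(nn.Module):
    """Wrap the segmentation net so the loss is computed inside forward().
    DataParallel scatters each tensor nested in the dict along dim 0, so every
    replica sees its own chunk of image/segmentation/weight_map and only
    returns a small loss tensor."""
    def __init__(self, seg_net):
        super().__init__()
        self.seg_net = seg_net

    def forward(self, batch):
        logits = self.seg_net(batch["image"])
        per_pixel = F.cross_entropy(logits, batch["segmentation"].long(),
                                    reduction="none")
        loss = (per_pixel * batch["weight_map"]).mean()
        return loss.unsqueeze(0)  # keep a dim so gather() can concatenate losses

# model = nn.DataParallel(SegModelWithLoss(seg_net).cuda())
# loss = model(batch).mean()    # mean over the per-GPU losses
# loss.backward()
```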

Option 2 is to send just the image to the Module, take back the calculated segmentation, and compute the loss and everything else in the main loop.
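And Option 2 would look roughly like this (again just a sketch; the 1x1-conv “network”, the loss and the saving logic are placeholders, and SegmentationDataset is the sketch from above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

seg_net = nn.Conv2d(3, 2, kernel_size=1)           # placeholder: image -> logits
model = nn.DataParallel(seg_net).cuda()
optimizer = torch.optim.Adam(model.parameters())
loader = DataLoader(SegmentationDataset(), batch_size=8)

for step, batch in enumerate(loader):
    image = batch["image"].cuda(non_blocking=True)
    target = batch["segmentation"].cuda(non_blocking=True).long()
    weight = batch["weight_map"].cuda(non_blocking=True)

    logits = model(image)                           # gathered back on the default GPU
    per_pixel = F.cross_entropy(logits, target, reduction="none")
    loss = (per_pixel * weight).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:                             # occasionally save an image for illustration
        torch.save(image[0].cpu(), f"sample_{step}.pt")
```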

I’m not sure I’m using these modules the way they were intended to be used; something here feels strange (batching dicts, doing work in the main loop, …).

Option 2 would be the standard approach, i.e. the loss would be calculated on the default device.
This also means that your default device will use more memory than the other GPUs, due to the gathered output, target, loss, etc. being stored there.
Option 1 is described in @Thomas_Wolf’s blog post.

Your approach sounds valid to me. Batching dicts should also work without a problem.
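The default collate_fn batches a list of dicts into a dict of batched tensors, so no custom collate is needed. A quick self-contained check (toy shapes, just for illustration):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"image": torch.randn(3, 8, 8), "weight_map": torch.rand(8, 8)}

batch = next(iter(DataLoader(DictDataset(), batch_size=4)))
print(batch["image"].shape)       # torch.Size([4, 3, 8, 8])
print(batch["weight_map"].shape)  # torch.Size([4, 8, 8])
```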