Delay DataParallel grad gathering

Hi, we’re currently training a huge model that combines language and vision end to end. Because of memory constraints, we cannot use a batch size larger than 1. However, we have 4 GPUs available, and we’d like to use them in parallel to reduce the total training time.

After reimplementing the scatter function so that each GPU receives one image (of varying size) and a variable-length phrase, we’ve found that the communication cost is larger than the actual forward pass through the network: on each step, around 3 GB of parameters must be replicated from the host GPU to the others, which adds up to a 6 GB transfer over the whole process. That is certainly a bottleneck compared to the parallel gain during the forward pass.
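For context, the reimplemented scatter can be sketched roughly like this: instead of splitting a fixed-shape batch tensor the way DataParallel’s default scatter does, each device is handed one whole variable-size example. This is only an illustration with placeholder data; the function name `scatter_examples` and the device-id list are assumptions, not our actual code.

```python
def scatter_examples(examples, device_ids):
    """Assign examples round-robin, one (image, phrase) pair per
    device, allowing each example to have a different size."""
    assigned = {dev: [] for dev in device_ids}
    for i, example in enumerate(examples):
        assigned[device_ids[i % len(device_ids)]].append(example)
    return assigned

# Placeholder batch: images of different sizes, phrases of different lengths.
batch = [("img_224x224", "a short phrase"),
         ("img_640x480", "a much longer phrase here"),
         ("img_128x128", "hi"),
         ("img_512x512", "another caption")]

mapping = scatter_examples(batch, device_ids=[0, 1, 2, 3])
# With 4 examples and 4 devices, each device gets exactly one example.
```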

I would like to know whether it is possible to delay the gathering of the gradients, so that each GPU can accumulate and average gradients locally after forwarding 100 examples, and only then send them back to the host GPU for the global optimization step, instead of doing this for every single example.
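The averaging scheme described above can be sketched numerically. The point is that averaging 100 per-example gradients on each replica and then averaging those local means on the host gives the same result as averaging all per-example gradients globally (since every replica processes the same number of examples), while requiring only one transfer per replica per 100 examples. This uses NumPy in place of GPU tensors purely for illustration; `local_accumulate` is a hypothetical name.

```python
import numpy as np

def local_accumulate(per_example_grads):
    """Sum the per-example gradients kept on one replica,
    then average them before sending to the host."""
    return sum(per_example_grads) / len(per_example_grads)

rng = np.random.default_rng(0)
n_gpus, n_examples = 4, 100

# Per-example gradients on each replica (a flat parameter vector each).
per_gpu = [[rng.normal(size=8) for _ in range(n_examples)]
           for _ in range(n_gpus)]

# One averaged gradient per replica: a single transfer instead of 100.
local_means = [local_accumulate(grads) for grads in per_gpu]

# Host averages the replica means to get the global gradient.
global_grad = sum(local_means) / n_gpus

# Equal to the naive per-example average over all 400 examples.
reference = sum(sum(g) for g in per_gpu) / (n_gpus * n_examples)
assert np.allclose(global_grad, reference)
```

Note the equality only holds when every replica accumulates over the same number of examples; with unequal counts the local means would need to be weighted accordingly.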

Thanks for your help!