DataParallel imbalanced memory usage

Well, the OP you cited is right: in the end, PyTorch gathers the whole batch of outputs onto one GPU to compute the loss, and that is what creates the imbalanced usage. On top of that, the optimizer's state is stored on the default GPU, which makes the problem worse. There is no obvious fix: DataParallel lets you choose the output GPU, but the imbalance just moves there. There are several open threads asking about the same issue. I tried playing with the DataParallel arguments device_ids and output_device, but with no luck.
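
To make it concrete, here is a minimal sketch (the model, sizes, and data are placeholders, and it assumes a machine with two visible GPUs) of the usual DataParallel setup with the device_ids and output_device arguments mentioned above. The forward pass is scattered across both GPUs, but the outputs are gathered back onto output_device and the loss and optimizer state live there too, which is where the extra memory shows up:

```python
import torch
import torch.nn as nn

# Toy model; any nn.Module works the same way here.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# Replicate across GPUs 0 and 1, gather the outputs back onto GPU 0.
dp_model = nn.DataParallel(model, device_ids=[0, 1], output_device=0).cuda(0)

criterion = nn.CrossEntropyLoss()
# Optimizer state is allocated on the device holding the parameters (GPU 0).
optimizer = torch.optim.SGD(dp_model.parameters(), lr=0.1)

inputs = torch.randn(64, 512).cuda(0)            # full batch starts on the default GPU
targets = torch.randint(0, 10, (64,)).cuda(0)

outputs = dp_model(inputs)        # scattered to GPUs 0 and 1, gathered back to GPU 0
loss = criterion(outputs, targets)  # loss computed on GPU 0, adding to its memory load
loss.backward()
optimizer.step()
```

Changing output_device to 1 only shifts the gathered outputs (and the loss computation) to the other GPU; it does not balance the memory between them.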