GPU parallel training problem

I want to implement a GPU-memory-efficient training program, but I am running into a threading/locking problem. Here are the steps:
1. Replicate my modules to 4 GPUs.
2. Compute the forward pass and the loss on these 4 GPUs.
3. Compute the backward pass on the 4 GPUs. Here comes the problem: the backward step cannot run in parallel.
From debugging, I found that when I run loss.backward() on GPU 1, the module on GPU 1 (which was copied from the default model) has parameters whose .grad is None, while the default model receives the gradients that GPU 1 actually computed.

models[1] is a copy of model.
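This behavior is by design: `torch.nn.parallel.replicate` (which `DataParallel` uses internally) creates replicas whose parameters are non-leaf tensors still connected to the source model's autograd graph, so `.grad` on a replica's parameters stays `None` and gradients accumulate on the original model. Here is a minimal CPU-only sketch that mimics this by replacing a copy's parameters with non-leaf views of the original's parameters; the module and names are illustrative, not your actual code:

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 3)
replica = copy.deepcopy(model)

# Mimic what replicate() does: make the replica's "parameters"
# non-leaf tensors that are autograd-connected to the source model.
for name in ["weight", "bias"]:
    src = getattr(model, name)
    del replica._parameters[name]        # drop the independent Parameter
    setattr(replica, name, src + 0)      # non-leaf view of the original

x = torch.randn(2, 4)
replica(x).sum().backward()

# replica.weight.grad is None (non-leaf tensor);
# model.weight.grad holds the gradient the replica's backward computed.
```

So the gradients you see on the default model are correct; the replicas are just views, and the optimizer step should be applied to the source model's parameters.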

Why won't DataParallel work in your use case?

Yeah, I implemented it with DataParallel first, but I have more than 20,000 categories. When PyTorch gathers all the outputs to GPU 0, GPU 0 consumes much more GPU memory than the other GPUs, which limits the batch size. My dataset is nearly 400 GB, so training already takes a very long time.
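One common workaround for this imbalance is to compute the loss inside the module that `DataParallel` wraps, so each replica reduces its large `(batch, num_classes)` logits to a tiny per-replica loss on its own GPU, and only those scalars are gathered to GPU 0. A hedged sketch (the `WithLoss` wrapper and the toy `nn.Linear` model are hypothetical, stand-ins for your actual model and criterion):

```python
import torch
import torch.nn as nn


class WithLoss(nn.Module):
    """Computes the loss on each replica's own GPU, so DataParallel
    gathers only tiny per-replica losses instead of full logits."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, target):
        logits = self.model(x)                   # large: (batch, num_classes)
        loss = self.criterion(logits, target)    # reduced on this replica
        return loss.unsqueeze(0)                 # shape (1,) per replica

# Toy model for illustration; runs on CPU here.
net = WithLoss(nn.Linear(8, 20000))
# On a multi-GPU machine you would wrap it instead:
#   net = nn.DataParallel(net.cuda())

x = torch.randn(4, 8)
y = torch.randint(0, 20000, (4,))
loss = net(x, y).mean()   # mean over the gathered per-replica losses
loss.backward()
```

This keeps GPU 0's memory use close to the other GPUs', which should let you raise the batch size again; for a 400 GB dataset, `DistributedDataParallel` with one process per GPU is usually faster still, since it avoids the gather entirely.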