Per-process backward() call

I have a question about running independent backward passes in each process under PyTorch’s distributed framework.

Suppose I spawn n worker processes with mp.spawn(), and in each worker process I call dist.init_process_group() with its rank and an init URL (e.g. a tcp:// address pointing at a port on localhost).
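
Roughly, the setup looks like the sketch below. The names worker and launch, the port 29500, and the gloo backend are placeholders for illustration only (with GPUs, nccl would be the usual backend), and the training loop itself is omitted.

```python
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size, port=29500):
    # Every spawned process joins the same group through one TCP endpoint.
    # The port and the "gloo" backend are assumptions made for this sketch.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=world_size,
    )
    joined = dist.get_world_size()
    # ... model construction and training would go here ...
    dist.destroy_process_group()
    return joined


def launch(world_size):
    # mp.spawn passes the process index as the first argument to worker.
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```

Calling launch(2) under an `if __name__ == "__main__":` guard would start two such workers.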

I’m parallelizing the model with
torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[rank])

When I call loss.backward() on my network’s loss, I believe the gradients are averaged across all GPUs in the process group (an all_reduce-style operation, so every GPU ends up with the same result).
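
To illustrate what I mean by averaging, here is a single-process sketch with no distributed calls at all. Two equal-sized shards stand in for the data seen by two ranks; the tiny Linear model, the random data, and the helper grads_on are made up for the example. With a mean-reduced loss and equal shard sizes, the average of the per-shard gradients matches the gradient on the full batch.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def grads_on(xb, yb):
    # Gradient of the mean-squared-error loss on one batch.
    model.zero_grad()
    torch.nn.functional.mse_loss(model(xb), yb).backward()
    return [p.grad.clone() for p in model.parameters()]


# Per-shard ("per-rank") gradients, which differ from each other.
local = [grads_on(x[:4], y[:4]), grads_on(x[4:], y[4:])]

# Averaging them is what I believe every rank ends up with after backward().
averaged = [(a + b) / 2 for a, b in zip(*local)]

# Gradient on the full batch, for comparison.
full = grads_on(x, y)
same = all(torch.allclose(a, f, atol=1e-6) for a, f in zip(averaged, full))
```

What I am after is the entries of local, not averaged.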

Suppose I want each process to ignore all other GPUs and call backward() only with respect to its own data subset, and then use these per-process gradients to update my main, shared model. Does PyTorch’s current framework support this, and if so, how?

Presumably two ways to do it would be (1) create a new process group right before the backward call and rejoin the joint process group immediately afterwards, or (2) not wrap the model in DistributedDataParallel at all, keep a separate model in each process, and do the send/recv/reduce operations myself. The second option most likely wouldn’t support synchronized BatchNorm, though.
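
For option (2), the manual reduce step might look something like the sketch below. The helper name average_gradients is hypothetical, and it assumes the process group is already initialized and that each process holds its own unwrapped copy of the model. Until it is called, p.grad on each rank holds only that rank's local gradient.

```python
import torch
import torch.distributed as dist


def average_gradients(model):
    # Sum each parameter's gradient across all ranks, then divide by the
    # number of ranks so every process ends up with the averaged gradient.
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size
```

Each process would call loss.backward(), do whatever it needs with its local gradients, and call average_gradients(model) only when it wants to synchronize.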