Derivative for aten::batch_norm_backward_elemt is not implemented

We are having trouble calculating a gradient penalty during distributed training.

Here is a simple code sample:

import argparse
import os
import sys
import tempfile
from urllib.parse import urlparse
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from simplerGAN import EncoderDecoder
import torchvision.transforms as transforms 
from torch.nn.parallel import DistributedDataParallel as DDP
def make_r1_gp(discr_real_pred, real_batch, use_ddp=False):
    """Compute the R1 gradient penalty of ``discr_real_pred`` w.r.t. ``real_batch``.

    The penalty is computed as follows:
    1. Take the gradient of ``discr_real_pred.sum()`` with respect to ``real_batch``.
    2. Flatten it per sample (first dimension = batch).
    3. Take its squared L2 norm.
    4. Average over the batch.

    The gradient graph is kept (``create_graph=True``) so the penalty itself
    can be back-propagated, i.e. training through it needs a double backward.

    Args:
        discr_real_pred: Tensor computed from ``real_batch``.
        real_batch: Leaf tensor with ``requires_grad=True``. Its first dimension
            is treated as the batch dimension. On return its ``requires_grad``
            flag is set to False.
        use_ddp: When True, skip the penalty and return 0. This is a
            workaround reported in this thread: with SyncBatchNorm under
            DDP, the double backward fails with "derivative for
            aten::batch_norm_backward_elemt is not implemented".

    Returns:
        A scalar tensor holding the penalty. Returns the int ``0`` instead when
        autograd is disabled or ``use_ddp`` is True.
    """
    if torch.is_grad_enabled() and not use_ddp:
        # create_graph=True keeps this gradient differentiable so the
        # penalty can contribute to the training loss.
        grad_real = torch.autograd.grad(
            outputs=discr_real_pred.sum(),
            inputs=real_batch,
            create_graph=True,
        )[0]
        # reshape (not view) also works when the gradient is non-contiguous.
        per_sample = grad_real.reshape(grad_real.shape[0], -1)
        grad_penalty = (per_sample.norm(2, dim=1) ** 2).mean()
    else:
        grad_penalty = 0
    # The input only needed gradients to build the penalty.
    real_batch.requires_grad = False
    return grad_penalty
def demo_basic(local_world_size, local_rank):
    """Run one DDP training step of ``EncoderDecoder`` with an R1 gradient penalty.

    The step proceeds as follows:
    1. Build the model on the first GPU assigned to this process and wrap it in
       ``DistributedDataParallel``.
    2. Reconstruct a random, normalized 1x3x254x254 input.
    3. Add ``make_r1_gp`` to the MSE reconstruction loss.
    4. Take one SGD step.

    Args:
        local_world_size: Number of processes on this node. The node's GPUs
            are split evenly between them.
        local_rank: Index of this process on the node.

    Raises:
        RuntimeError: If this process would be assigned no GPU.
    """
    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
    # rank 0 uses GPUs [0, 1, 2, 3] and
    # rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    if n == 0:
        raise RuntimeError(
            f"No GPU available for local_rank {local_rank}: "
            f"{torch.cuda.device_count()} GPU(s) for {local_world_size} process(es)"
        )
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n",
        end="",
    )

    transform = transforms.Compose(
        [transforms.Resize(254), transforms.Normalize([0.5], [0.5])]
    )

    model = EncoderDecoder().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    inp = transform(torch.randn(1, 3, 254, 254).to(device_ids[0]))
    # The R1 penalty differentiates the output with respect to the input.
    inp.requires_grad = True

    optimizer.zero_grad()
    outputs = ddp_model(inp)
    loss = loss_fn(outputs, inp)
    # NOTE(review): if the model contains SyncBatchNorm, loss.backward()
    # reportedly fails here because SyncBatchNorm lacks a double backward.
    # Pass use_ddp=True to make_r1_gp (drops the penalty) or use plain
    # BatchNorm in that case.
    loss += make_r1_gp(outputs, inp)
    loss.backward()
    optimizer.step()
def spmd_main(local_world_size, local_rank):
    """Initialize the process group, run ``demo_basic``, then tear the group down.

    Reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.
    A missing variable raises KeyError.

    Backend selection:
    - On Windows, the Gloo backend with a FileStore is used. The store is taken
      from INIT_METHOD if set (it must be a ``file`` URL); otherwise a file in
      the temp directory is used.
    - On other platforms, NCCL is used with the default (env://) initialization.

    Args:
        local_world_size: Number of processes on this node.
        local_rank: Index of this process on the node.

    Raises:
        ValueError: On Windows, if INIT_METHOD is not a ``file`` URL.
    """
    # These are the parameters used to initialize the process group
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }

    if sys.platform == "win32":
        # Distributed package only covers collective communications with Gloo
        # backend and FileStore on Windows platform. Set init_method parameter
        # in init_process_group to a local file.
        if "INIT_METHOD" in os.environ:
            print(f"init_method is {os.environ['INIT_METHOD']}")
            url_obj = urlparse(os.environ["INIT_METHOD"])
            if url_obj.scheme.lower() != "file":
                raise ValueError("Windows only supports FileStore")
            init_method = os.environ["INIT_METHOD"]
        else:
            # This is an example application; for convenience, create a file
            # in the temp directory.
            temp_dir = tempfile.gettempdir()
            init_method = f"file:///{os.path.join(temp_dir, 'ddp_example')}"
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=int(env_dict["RANK"]),
            world_size=int(env_dict["WORLD_SIZE"]),
        )
    else:
        print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
        dist.init_process_group(backend="nccl")

    print(
        f"[{os.getpid()}]: world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()} \n",
        end="",
    )

    try:
        demo_basic(local_world_size, local_rank)
    finally:
        # Tear down the process group, even if training raised.
        dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Passed in by torch.distributed.launch. PyTorch 2.x launchers spell it
    # --local-rank, and torchrun only sets the LOCAL_RANK environment
    # variable, so accept both spellings and fall back to LOCAL_RANK.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    # This needs to be explicitly passed in
    parser.add_argument("--local_world_size", type=int, default=1)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess
    spmd_main(args.local_world_size, args.local_rank)

Here is the simple model we created:

import torch
import torch.nn as nn
class EncoderDecoder(nn.Module):
    """Minimal convolutional encoder-decoder that reconstructs its input.

    The encoder maps 3 -> 64 channels and halves height and width with a 2x2
    max-pool. The decoder upsamples by 2 and maps 64 -> 32 -> 3 channels.
    For inputs with even height and width, the output therefore has the same
    shape as the input, so it can be compared element-wise (e.g. with
    ``nn.MSELoss``).
    """

    def __init__(self):
        super(EncoderDecoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            # Undo the encoder's 2x spatial downsampling.
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Encode then decode ``x`` and return the 3-channel reconstruction."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

When we run the training function `demo_basic` without distributed training, the code works fine. When we launch it with distributed training, however, it throws the following error:

"derivative for aten::batch_norm_backward_elemt is not implemented"

How can we work around this?

Apologies for the mistake in the comment in the function `make_r1_gp`. The comment should be:

@ptrblck Any suggestions?

I encountered the same problem when turning on SyncBN. If I turn it off, the code works well.