Derivative for aten::batch_norm_backward_elemt is not implemented

Hi Guys,

We are having trouble calculating a gradient penalty in distributed training.

Here is a minimal code sample:

​
import argparse
import os
import sys
import tempfile
from urllib.parse import urlparse
​
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from simplerGAN import EncoderDecoder
import torchvision.transforms as transforms 
​
from torch.nn.parallel import DistributedDataParallel as DDP
​
​
def make_r1_gp(discr_real_pred, real_batch, use_ddp=False):
    """Return the R1 gradient penalty of ``discr_real_pred`` w.r.t. ``real_batch``.

    The penalty is the mean, over the first dimension, of the squared L2 norm
    of ``d(discr_real_pred.sum()) / d(real_batch)``, where each sample's
    gradient is flattened to a vector. The gradient is taken with
    ``create_graph=True`` so the penalty can itself be back-propagated.

    Args:
        discr_real_pred: Tensor computed from ``real_batch``.
        real_batch: Tensor the gradient is taken with respect to. It must
            require grad when the penalty is computed; its first dimension
            is treated as the batch dimension.
        use_ddp: When true, the penalty is skipped and ``0`` is returned.
            This is a temporary switch for distributed runs until the
            penalty can be computed there.

    Returns:
        A scalar tensor with the penalty, or the integer ``0`` when grad mode
        is disabled or ``use_ddp`` is true.

    Side effects:
        Gradient tracking is switched off on ``real_batch`` when it is a
        leaf tensor.
    """
    if torch.is_grad_enabled() and not use_ddp:
        (grad_real,) = torch.autograd.grad(
            outputs=discr_real_pred.sum(),
            inputs=real_batch,
            create_graph=True,
        )
        # reshape (not view): the returned gradient is not guaranteed to be
        # contiguous, and view raises on layouts it cannot flatten in place.
        per_sample = grad_real.reshape(grad_real.shape[0], -1)
        grad_penalty = (per_sample.norm(2, dim=1) ** 2).mean()
    else:
        grad_penalty = 0

    # requires_grad can only be changed on leaf tensors; assigning it on a
    # computed tensor raises RuntimeError, so non-leaf inputs are left alone.
    if real_batch.is_leaf:
        real_batch.requires_grad = False

    return grad_penalty
​
​
def demo_basic(local_world_size, local_rank):
    """Run one forward/backward/optimizer step of ``EncoderDecoder`` under DDP.

    The visible CUDA devices are split evenly between the local processes.
    This process moves the model to the first device of its share, wraps it
    in ``DistributedDataParallel``, feeds one random batch and minimises the
    MSE between output and input plus the ``make_r1_gp`` penalty.

    Args:
        local_world_size: Number of processes sharing this machine's GPUs.
        local_rank: Index of this process among them; selects its GPU slice.

    Raises:
        RuntimeError: If there are fewer CUDA devices than local processes,
            so this process would be assigned no device.
    """
    # Set up the devices for this process. For local_world_size = 2 and
    # 8 GPUs, rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    if n == 0:
        # Fail early with a clear message instead of an IndexError on
        # device_ids[0] further down.
        raise RuntimeError(
            "demo_basic needs at least one CUDA device per local process, "
            f"but found {torch.cuda.device_count()} device(s) for "
            f"local_world_size={local_world_size}"
        )
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    transform = transforms.Compose(
        [transforms.Resize(254), transforms.Normalize([0.5], [0.5])]
    )

    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end=''
    )

    # NOTE(review): if the model contains SyncBatchNorm layers, PyTorch
    # documents them as supported only with one GPU per process - confirm
    # that n == 1 in the intended launch configuration.
    model = EncoderDecoder().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    inp = transform(torch.randn(1, 3, 254, 254).to(device_ids[0]))
    # The gradient penalty differentiates the output w.r.t. the input, so the
    # input itself has to track gradients.
    inp.requires_grad = True
    outputs = ddp_model(inp)
    loss = loss_fn(outputs, inp)
    loss += make_r1_gp(outputs, inp)
    loss.backward()
    optimizer.step()
​
​
def spmd_main(local_world_size, local_rank):
    """Initialise the process group, run ``demo_basic`` and tear the group down.

    On Windows the Gloo backend is used with a ``file://`` init method, taken
    from the ``INIT_METHOD`` environment variable or defaulting to a file in
    the temp directory. On every other platform the NCCL backend is
    initialised with its default init method.

    Args:
        local_world_size: Forwarded to ``demo_basic``.
        local_rank: Forwarded to ``demo_basic``.

    Raises:
        KeyError: If ``MASTER_ADDR``, ``MASTER_PORT``, ``RANK`` or
            ``WORLD_SIZE`` is missing from the environment.
        ValueError: On Windows, if ``INIT_METHOD`` is not a ``file`` URL.
    """
    # These are the parameters used to initialize the process group.
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }

    if sys.platform == "win32":
        # The distributed package only covers collective communications with
        # the Gloo backend and FileStore on Windows. Set the init_method
        # parameter of init_process_group to a local file.
        if "INIT_METHOD" in os.environ:
            print(f"init_method is {os.environ['INIT_METHOD']}")
            url_obj = urlparse(os.environ["INIT_METHOD"])
            if url_obj.scheme.lower() != "file":
                raise ValueError("Windows only supports FileStore")
            init_method = os.environ["INIT_METHOD"]
        else:
            # This is an example application; for convenience, we create a
            # file in the temp dir.
            temp_dir = tempfile.gettempdir()
            init_method = f"file:///{os.path.join(temp_dir, 'ddp_example')}"
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=int(env_dict["RANK"]),
            world_size=int(env_dict["WORLD_SIZE"]),
        )
    else:
        print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
        dist.init_process_group(backend="nccl")

    print(
        f"[{os.getpid()}]: world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()} \n", end=''
    )

    demo_basic(local_world_size, local_rank)

    # Tear down the process group.
    dist.destroy_process_group()
​
​
if __name__ == "__main__":
    # Command-line entry point: parse the two launcher options and hand them
    # straight to spmd_main (no subprocess is spawned here).
    cli = argparse.ArgumentParser()
    # Supplied by launch.py.
    cli.add_argument("--local_rank", type=int, default=0)
    # Has to be given explicitly by the caller.
    cli.add_argument("--local_world_size", type=int, default=1)
    options = cli.parse_args()
    spmd_main(
        local_world_size=options.local_world_size,
        local_rank=options.local_rank,
    )



We have created a simple model, shown here:

import torch
import torch.nn as nn
​
class EncoderDecoder(nn.Module):
    """Small convolutional encoder/decoder with synchronised batch norm.

    The encoder is a 3x3 convolution (3 -> 64 channels), SyncBatchNorm, ReLU
    and a 2x2 max-pool with stride 2. The decoder is a 3x3 convolution
    (64 -> 32 channels), SyncBatchNorm, ReLU, a 2x upsample, a 3x3
    convolution (32 -> 3 channels) and Tanh.
    """

    def __init__(self):
        super().__init__()

        # Stride-1, padding-1 convolutions keep the spatial size; only the
        # max-pool changes it (halving height and width).
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # The upsample doubles height and width, undoing the encoder's
        # pooling for even input sizes.
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Args:
            x: Input tensor with 3 channels; presumably shaped
                (batch, 3, height, width) - confirm against the caller.

        Returns:
            The decoder output: a 3-channel tensor passed through Tanh.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x

When we run the "demo_basic" training function without distributed training, the code works fine, but when we launch it with distributed training it throws the following error:

"derivative for aten::batch_norm_backward_elemt is not implemented"

How can we work around this?