"derivative for aten::grid_sampler_2d_backward is not implemented" when using Adaptive Discriminator Augmentation (ada)

Hello,

I have re-implemented StyleGAN2-ada since the original implementations was giving me many troubles. I have used torchvision.transforms.functional to provide random data augmentation on real and fake images from the generator.

Unfortunately if am facing a problem, apparently at random during the training process (usually around iteration 15K). Still not able to understand how to reproduce the error I get:

Init ADA with p=0.01
 10%|████▊                                            | 14687/150000 [13:37:00<125:27:07,  3.34s/it]

Traceback (most recent call last):
  File "trainStyleGAN.py", line 581, in <module>
    main()
  File "trainStyleGAN.py", line 552, in main
    trainer.train() #start training loop
  File "trainStyleGAN.py", line 525, in train
    self.step(i)
  File "trainStyleGAN.py", line 445, in step
    gen_loss.backward()
  File "/home/alberto/GAN_venv/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/alberto/GAN_venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 197, in backward
    
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: derivative for aten::grid_sampler_2d_backward is not implemented

I cannot understand what is going on. I only now that is for sure related with ADA because if I run the code without any augmentation, no errors at all.

I add below some rows of my code. […] means I have skipped something to be short.

Augmenter Class:

class AugmentPipe(torch.nn.Module):
    def __init__(self, init_p=0,
        xflip=0, rotate90=0, xint=0, xint_max=0.125,
        [...]
    ):
        super().__init__()
        self.p = init_p

        # Pixel blitting.
        self.xflip            = float(xflip)            # Probability multiplier for x-flip.
        self.rotate90         = float(rotate90)         # Probability multiplier for 90 degree rotations.
        self.xint             = float(xint)             # Probability multiplier for integer translation.
        self.xint_max         = float(xint_max)         # Range of integer translation, relative to image dimensions.

        [...] #more transforms here I am not reporting


    def forward(self, images, masks=None, debug=False, args=None):

        random.seed()
        torch.random.seed()
    
        if debug:
            self.p = 1.0

        if self.p == 0:
           return images, masks

        assert isinstance(images, torch.Tensor) and images.ndim == 4
        self.device = images.device

        batch_size, num_channels, height, width = images.shape

        with_mask = masks is not None
        if with_mask: 
            assert isinstance(images, torch.Tensor) and masks.device == self.device and masks.ndim == 4
            batch_size_M, num_channels_M, height_M, width_M = masks.shape
        
        if(self.xflip > 0):
            apply_T = random.random()
            i = torch.randint(0,2,[images.shape[0]], device=self.device)*self.xflip * self.p > apply_T
            
            if torch.sum(i) > 0:
                args.file_log.debug(f"[{args.idx} , {args.debug_str}] xflip: {i}")
                images[i,:,:,:] = TF.hflip(images[i,:,:,:])
                if masks is not None:
                    masks[i,:,:,:] = TF.hflip(masks[i,:,:,:])

            if torch.any(images.isnan()):
                        print("flip is NaN")
        
        if(self.rotate90 > 0):
            apply_T = random.random()
            i = torch.randint(0,2,[images.shape[0]], device=self.device) * self.rotate90 * self.p > apply_T
            if torch.sum(i) > 0:
                args.file_log.debug(f"[{args.idx} , {args.debug_str}] rotate90: {i}")
                images[i,:,:,:] = TF.rotate(images[i,:,:,:], 90, torchvision.transforms.InterpolationMode.BILINEAR)
                if masks is not None:
                    masks[i,:,:,:] = TF.rotate(masks[i,:,:,:], 90, torchvision.transforms.InterpolationMode.NEAREST)
            
            if torch.any(images.isnan()):
                        print("rot90 is NaN")

        if(self.xint > 0):
            apply_T = random.random()
            i = torch.randint(0,2,[images.shape[0]], device=self.device) * self.xint * self.p > apply_T 
            t = random.random() * self.xint_max
            if torch.sum(i) > 0:
                args.file_log.debug(f"[{args.idx} , {args.debug_str}] xint: {i}")

                images[i,:,:,:] = TF.affine(images[i,:,:,:], 0, [t,t], 1, 0, torchvision.transforms.InterpolationMode.BILINEAR)
                if masks is not None:
                    masks[i,:,:,:] = TF.affine(masks[i,:,:,:], 0, [t,t], 1, 0, torchvision.transforms.InterpolationMode.NEAREST)
                
            if torch.any(images.isnan()):
                    print("xint is NaN")

[...]


        return images, masks

    def update_p(self, p):
            self.p = p

Training step:

[...]
self.augmenter = AugmentPipe(init_p=self.p, **augpipe_specs['bgcfn']).requires_grad_(False).to(self.args.device)
[...]

    def step(self, idx: int):
        """
        ### Training Step
        """

        self.args.idx = idx

        # Train the discriminator
        # Reset gradients
        self.discriminator_optimizer.zero_grad()

        # Accumulate gradients for `gradient_accumulate_steps`
        for i in range(self.args.gradient_accumulate_steps):
            
            if self.args.use_ada and idx % self.args.p_update_interval == 0:
                self.p = self.p_sched.get_p() 
                self.augmenter.update_p(self.p)
                self.logger.add_scalar("[TRAIN] ADA PROB (p)", self.p, global_step=idx)
                

            # Sample images from generator
            generated_images, _ = self.generate_images(self.args.batch_size)

            #Apply ADA
            if self.args.use_ada:
                self.args.debug_str = "DIC gen images"
                generated_images, _ = self.augmenter(generated_images, args=self.args, debug=True)

            # Discriminator classification for generated images
            fake_output = self.discriminator(generated_images.detach())
            self.avg_log["avg_out_D_fake"] += torch.mean(fake_output).item()

            # Get real images from the data loader
            real_images, patches = next(self.loader)
            real_images = real_images.to(self.args.device)

            debug_NaN("Real", real_images, self.logger, idx)

            #APPLY ADA
            real_images_original = real_images.clone()
            if self.args.use_ada:
                self.args.debug_str = "DIC real images"
                real_images, _ = self.augmenter(real_images, args=self.args, debug=True)
                debug_NaN("AUGMENTED Real", real_images, self.logger, idx)

            # We need to calculate gradients w.r.t. real images for gradient penalty
            if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
                real_images.requires_grad_()
            # Discriminator classification for real images
            real_output = self.discriminator(real_images)
            self.avg_log["avg_out_D_real"] += torch.mean(real_output).item()

            if self.args.use_ada:
                self.p_sched.step(real_output) #update the p scheduler

            # Get discriminator loss
            real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
            disc_loss = real_loss + fake_loss

            # Add gradient penalty
            if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
                # Calculate and log gradient penalty
                gp = self.gradient_penalty(real_images, real_output)
                # Multiply by coefficient and add gradient penalty
                disc_loss = disc_loss + 0.5 * self.args.gradient_penalty_coefficient * gp * self.args.lazy_gradient_penalty_interval


            # Log discriminator loss
            [...]

            if (idx + 1) % self.args.log_generated_interval == 0:

                ######### VALIDATION #######
                ############################
                self.discriminator.eval()
                self.avg_log["avg_out_D_val"] = 0

                for j, data in enumerate(self.dataloader_val):
                    real_images_val, patches_val = data
                    real_images_val = real_images_val.to(self.args.device)
                    
                    val_output = self.discriminator(real_images_val)
                    self.avg_log["avg_out_D_val"] += torch.mean(val_output).item()

                self.discriminator.train()

                [...] #logging in tensorboard


            # Compute gradients
            disc_loss.backward()

[...] #similarly for the generator

UPDATE

I might find the issue: when computing Path Length Penalty (PLP) regularization, the following operation is performed in the generator training step:

        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)

This probably cause a double gradient computation for w and the generated image x that ends up with the aforementioned error (as pointed out here).

I still don’t understand how it is correlated with ADA. No error is disable ADA, even if PLP is still present.

P.S. I based my code on this implementation of StyleGAN2.