RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation -- when training GANs with DDP and BatchNorm

I am using DDP for a simple GAN project (a toy problem on MNIST). The following error occurs when I run the training script:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
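
Following the hint, I enabled anomaly detection once at the top of my training script (a minimal sketch of what I did):

import torch

# report the forward op that created the tensor that later fails in backward;
# enabled only while debugging, since it slows training down considerably
torch.autograd.set_detect_anomaly(True)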

With torch.autograd.set_detect_anomaly(True) enabled, the traceback pointed to a BatchNorm layer in the discriminator. Here is the code for my discriminator:

import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),  # 14x14 -> 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False), # 7x7 -> 3x3
            nn.BatchNorm2d(256),  # the [256] tensor in the error matches this layer
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=0, bias=False),   # 3x3 -> 1x1
            nn.Flatten(),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
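
My generator is a standard DCGAN-style net that maps a (N, 100) latent vector to (N, 1, 28, 28) images; it looks roughly like this (a sketch, not my exact code):

class Generator(nn.Module):
    # sketch: a minimal net matching the shapes used in the training loop
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Unflatten(1, (100, 1, 1)),                                                  # (N, 100) -> (N, 100, 1, 1)
            nn.ConvTranspose2d(100, 256, kernel_size=7, stride=1, padding=0, bias=False),  # 1x1 -> 7x7
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),  # 7x7 -> 14x14
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False),    # 14x14 -> 28x28
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)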

Here is how I wrap both models with DDP after initializing the process groups:

def compile(self, generator, discriminator):
    # convert every BatchNorm layer to SyncBatchNorm before wrapping with DDP
    self.generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator)
    self.discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

    self.generator = DDP(self.generator.to(self.gpu), device_ids=[self.gpu])
    self.discriminator = DDP(self.discriminator.to(self.gpu), device_ids=[self.gpu])
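
For reference, compile runs after the process groups have been set up; a minimal sketch of that initialization (assuming one process per GPU launched with torchrun, and that self.gpu comes from LOCAL_RANK; my exact code may differ slightly):

import os
import torch
import torch.distributed as dist

def setup(self):
    # torchrun provides RANK, WORLD_SIZE and LOCAL_RANK via the environment
    dist.init_process_group(backend="nccl")
    self.gpu = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(self.gpu)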

And finally, here is my training loop:

for self.epoch in range(1, epochs+1):
    callback_handler.on_epoch_begin()
    self.generator.train()
    self.discriminator.train()

    for self.batch, (inputs, targets) in enumerate(self.train_loader):
        callback_handler.on_batch_begin()

        # train the discriminator
        # label smoothing: real targets are 0.9 (note: 0.1 * torch.zeros(...) is still all zeros)
        real_labels = 0.9 * torch.ones(inputs.size(0), 1).to(self.gpu)
        fake_labels = 0.1 * torch.zeros(inputs.size(0), 1).to(self.gpu)

        outputs_real = self.discriminator(inputs.unsqueeze(1).to(self.gpu).float())
        loss_real = self.loss_function(outputs_real, real_labels)

        z = torch.randn(inputs.size(0), 100).to(self.gpu)
        fake_images = self.generator(z)

        # second discriminator forward pass in the same iteration
        outputs_fake = self.discriminator(fake_images)
        loss_fake = self.loss_function(outputs_fake, fake_labels)

        self.loss_discriminator = loss_real + loss_fake
        self.optimizer_discriminator.zero_grad()
        self.loss_discriminator.backward()
        self.optimizer_discriminator.step()

        # average the loss across ranks, for logging only
        torch.distributed.all_reduce(self.loss_discriminator, op=torch.distributed.ReduceOp.AVG)

        # train the generator
        z = torch.randn(inputs.size(0), 100).to(self.gpu)
        fake_images = self.generator(z)

        outputs = self.discriminator(fake_images)
        self.loss_generator = self.loss_function(outputs, real_labels)

        self.optimizer_generator.zero_grad()
        self.loss_generator.backward()
        self.optimizer_generator.step()
        torch.distributed.all_reduce(self.loss_generator, op=torch.distributed.ReduceOp.AVG)

        self.epoch_loss_generator += self.loss_generator
        self.epoch_loss_discriminator += self.loss_discriminator

        callback_handler.on_batch_end()

    self.train_loss_generator.append(self.epoch_loss_generator.item() / len(self.train_loader))
    self.train_loss_discriminator.append(self.epoch_loss_discriminator.item() / len(self.train_loader))

    self.epoch_loss_discriminator = torch.tensor(0.0).to(self.gpu)
    self.epoch_loss_generator = torch.tensor(0.0).to(self.gpu)

Strangely enough, if I run the code without the BatchNorm layers (DDP on, same training loop as above), it runs fine. If I disable DDP (BatchNorm on, same training loop), it also runs fine. And if I comment out the second discriminator call during the discriminator step (outputs_fake = self.discriminator(fake_images)), it runs fine as well.

So the problem seems to come from running DDP with a model that contains BatchNorm layers and is called twice in the same iteration (if it is called only once, the error vanishes). Any ideas?
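
For completeness, here is a stripped-down, self-contained sketch of the pattern I think is responsible (an approximation of my setup, not my actual training code; it needs at least two CUDA GPUs to exercise SyncBatchNorm):

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # tiny stand-in for the discriminator: Conv + BatchNorm is enough
    disc = nn.Sequential(
        nn.Conv2d(1, 256, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),
        nn.Flatten(),
    )
    disc = nn.SyncBatchNorm.convert_sync_batchnorm(disc)
    disc = DDP(disc.to(rank), device_ids=[rank])

    real = torch.randn(8, 1, 28, 28, device=f"cuda:{rank}")
    fake = torch.randn(8, 1, 28, 28, device=f"cuda:{rank}")

    # two forward passes through the same DDP-wrapped BatchNorm model,
    # followed by a single backward -- the pattern from my discriminator step
    out_real = disc(real)
    out_fake = disc(fake)
    loss = out_real.mean() + out_fake.mean()
    loss.backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)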