SRGAN artifacts caused by gradient

Hi,

I am trying to implement SRGAN (super resolution).
I try to train my model on a batch of 10 images, just in order to see if I am able to correctly overfit.
Here are the results after 200 epochs on these 10 images :

I highly suspect my training loop and my lack of understanding of how gradients flow in PyTorch.
Here is the code :

# Enable anomaly detection while debugging gradient problems; disable it for
# real training runs — it slows every backward pass considerably.
torch.autograd.set_detect_anomaly(True)

for epoch in range(epochs):

    for i, batch in enumerate(train_data):

        # Low-resolution input and high-resolution target.
        lr, hr = batch["lr"], batch["hr"]
        lr = lr.to(device)
        hr = hr.to(device)
        bs = lr.shape[0]

        # ------------------------------------------------------------------
        # Train discriminator: push D(hr) toward 1 and D(G(lr)) toward 0.
        # ------------------------------------------------------------------
        d_net.zero_grad()

        # Real batch, target label 1.
        real_label = torch.full((bs,), 1.0, device=device, dtype=torch.float32)
        real_pred = d_net(hr).view(-1)
        real_loss = bce_criterion(real_pred, real_label)

        # Fake batch, target label 0. detach() stops the discriminator loss
        # from backpropagating into the generator.
        fake_label = torch.zeros(bs, device=device, dtype=torch.float32)
        sr = g_net(lr)
        fake_pred = d_net(sr.detach()).view(-1)
        fake_loss = bce_criterion(fake_pred, fake_label)

        # Sum the two losses and backpropagate ONCE instead of calling
        # backward() twice — same accumulated gradients, less graph traffic.
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optim.step()

        with torch.no_grad():
            p_real = real_pred.mean(axis=0)
            p_fake = fake_pred.mean(axis=0)

        # ------------------------------------------------------------------
        # Train generator: content loss + adversarial loss. Target label 1
        # here because the generator wins when D classifies sr as real.
        # ------------------------------------------------------------------
        g_net.zero_grad()
        sr_preds = d_net(sr).view(-1)
        adv_loss = adv_coeff * bce_criterion(sr_preds, real_label)
        # NOTE(review): argument order assumed to be content_loss(target,
        # prediction) as in the original — confirm against its definition.
        content_criterion = content_loss(hr, sr)
        # Sum first, then one backward(). The original called backward()
        # twice over the same generator graph, forcing retain_graph=True;
        # a single backward on the summed loss makes that unnecessary.
        g_loss = content_criterion + adv_loss
        g_loss.backward()
        g_optim.step()

        with torch.no_grad():
            p_fake_trick = sr_preds.mean(axis=0)

        # monitoring
        if i % 10 == 0:
            print("d_loss = ", d_loss.item())
            print("g_loss = ", g_loss.item())
            print("D(Real) = ", p_real.item())
            # p_fake was computed but never reported in the original.
            print("D(Fake)1 = ", p_fake.item())
            print("D(Fake)2 = ", p_fake_trick.item())
            print("Content loss = ", content_criterion.item())

            with torch.no_grad():
                for test_batch in train_data:
                    lr_test, hr_test = test_batch["lr"], test_batch["hr"]
                    lr_test = lr_test.to(device)
                    hr_test = hr_test.to(device)
                    print("LR")
                    display_grid(lr_test.cpu())
                    # Separate name so we don't shadow the training-batch
                    # `sr` bound above.
                    sr_test = g_net(lr_test)
                    print("Generated")
                    display_grid(sr_test.cpu())
                    print("Real")
                    display_grid(hr_test.cpu())
                    break

    print("Epoch : ", epoch)

Thank you very very much for your help