Hello,
I have re-implemented StyleGAN2-ada since the original implementations was giving me many troubles. I have used torchvision.transforms.functional
to provide random data augmentation on real and fake images from the generator.
Unfortunately if am facing a problem, apparently at random during the training process (usually around iteration 15K). Still not able to understand how to reproduce the error I get:
Init ADA with p=0.01
10%|████▊ | 14687/150000 [13:37:00<125:27:07, 3.34s/it]
Traceback (most recent call last):
File "trainStyleGAN.py", line 581, in <module>
main()
File "trainStyleGAN.py", line 552, in main
trainer.train() #start training loop
File "trainStyleGAN.py", line 525, in train
self.step(i)
File "trainStyleGAN.py", line 445, in step
gen_loss.backward()
File "/home/alberto/GAN_venv/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/alberto/GAN_venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: derivative for aten::grid_sampler_2d_backward is not implemented
I cannot understand what is going on. I only now that is for sure related with ADA because if I run the code without any augmentation, no errors at all.
I add below some rows of my code. […] means I have skipped something to be short.
Augmenter Class:
class AugmentPipe(torch.nn.Module):
def __init__(self, init_p=0,
xflip=0, rotate90=0, xint=0, xint_max=0.125,
[...]
):
super().__init__()
self.p = init_p
# Pixel blitting.
self.xflip = float(xflip) # Probability multiplier for x-flip.
self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
self.xint = float(xint) # Probability multiplier for integer translation.
self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
[...] #more transforms here I am not reporting
def forward(self, images, masks=None, debug=False, args=None):
random.seed()
torch.random.seed()
if debug:
self.p = 1.0
if self.p == 0:
return images, masks
assert isinstance(images, torch.Tensor) and images.ndim == 4
self.device = images.device
batch_size, num_channels, height, width = images.shape
with_mask = masks is not None
if with_mask:
assert isinstance(images, torch.Tensor) and masks.device == self.device and masks.ndim == 4
batch_size_M, num_channels_M, height_M, width_M = masks.shape
if(self.xflip > 0):
apply_T = random.random()
i = torch.randint(0,2,[images.shape[0]], device=self.device)*self.xflip * self.p > apply_T
if torch.sum(i) > 0:
args.file_log.debug(f"[{args.idx} , {args.debug_str}] xflip: {i}")
images[i,:,:,:] = TF.hflip(images[i,:,:,:])
if masks is not None:
masks[i,:,:,:] = TF.hflip(masks[i,:,:,:])
if torch.any(images.isnan()):
print("flip is NaN")
if(self.rotate90 > 0):
apply_T = random.random()
i = torch.randint(0,2,[images.shape[0]], device=self.device) * self.rotate90 * self.p > apply_T
if torch.sum(i) > 0:
args.file_log.debug(f"[{args.idx} , {args.debug_str}] rotate90: {i}")
images[i,:,:,:] = TF.rotate(images[i,:,:,:], 90, torchvision.transforms.InterpolationMode.BILINEAR)
if masks is not None:
masks[i,:,:,:] = TF.rotate(masks[i,:,:,:], 90, torchvision.transforms.InterpolationMode.NEAREST)
if torch.any(images.isnan()):
print("rot90 is NaN")
if(self.xint > 0):
apply_T = random.random()
i = torch.randint(0,2,[images.shape[0]], device=self.device) * self.xint * self.p > apply_T
t = random.random() * self.xint_max
if torch.sum(i) > 0:
args.file_log.debug(f"[{args.idx} , {args.debug_str}] xint: {i}")
images[i,:,:,:] = TF.affine(images[i,:,:,:], 0, [t,t], 1, 0, torchvision.transforms.InterpolationMode.BILINEAR)
if masks is not None:
masks[i,:,:,:] = TF.affine(masks[i,:,:,:], 0, [t,t], 1, 0, torchvision.transforms.InterpolationMode.NEAREST)
if torch.any(images.isnan()):
print("xint is NaN")
[...]
return images, masks
def update_p(self, p):
self.p = p
Training step:
[...]
self.augmenter = AugmentPipe(init_p=self.p, **augpipe_specs['bgcfn']).requires_grad_(False).to(self.args.device)
[...]
def step(self, idx: int):
"""
### Training Step
"""
self.args.idx = idx
# Train the discriminator
# Reset gradients
self.discriminator_optimizer.zero_grad()
# Accumulate gradients for `gradient_accumulate_steps`
for i in range(self.args.gradient_accumulate_steps):
if self.args.use_ada and idx % self.args.p_update_interval == 0:
self.p = self.p_sched.get_p()
self.augmenter.update_p(self.p)
self.logger.add_scalar("[TRAIN] ADA PROB (p)", self.p, global_step=idx)
# Sample images from generator
generated_images, _ = self.generate_images(self.args.batch_size)
#Apply ADA
if self.args.use_ada:
self.args.debug_str = "DIC gen images"
generated_images, _ = self.augmenter(generated_images, args=self.args, debug=True)
# Discriminator classification for generated images
fake_output = self.discriminator(generated_images.detach())
self.avg_log["avg_out_D_fake"] += torch.mean(fake_output).item()
# Get real images from the data loader
real_images, patches = next(self.loader)
real_images = real_images.to(self.args.device)
debug_NaN("Real", real_images, self.logger, idx)
#APPLY ADA
real_images_original = real_images.clone()
if self.args.use_ada:
self.args.debug_str = "DIC real images"
real_images, _ = self.augmenter(real_images, args=self.args, debug=True)
debug_NaN("AUGMENTED Real", real_images, self.logger, idx)
# We need to calculate gradients w.r.t. real images for gradient penalty
if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
real_images.requires_grad_()
# Discriminator classification for real images
real_output = self.discriminator(real_images)
self.avg_log["avg_out_D_real"] += torch.mean(real_output).item()
if self.args.use_ada:
self.p_sched.step(real_output) #update the p scheduler
# Get discriminator loss
real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
disc_loss = real_loss + fake_loss
# Add gradient penalty
if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
# Calculate and log gradient penalty
gp = self.gradient_penalty(real_images, real_output)
# Multiply by coefficient and add gradient penalty
disc_loss = disc_loss + 0.5 * self.args.gradient_penalty_coefficient * gp * self.args.lazy_gradient_penalty_interval
# Log discriminator loss
[...]
if (idx + 1) % self.args.log_generated_interval == 0:
######### VALIDATION #######
############################
self.discriminator.eval()
self.avg_log["avg_out_D_val"] = 0
for j, data in enumerate(self.dataloader_val):
real_images_val, patches_val = data
real_images_val = real_images_val.to(self.args.device)
val_output = self.discriminator(real_images_val)
self.avg_log["avg_out_D_val"] += torch.mean(val_output).item()
self.discriminator.train()
[...] #logging in tensorboard
# Compute gradients
disc_loss.backward()
[...] #similarly for the generator