Hi,
The error is arising in the function backward_G:
def backward_G(self):
    lambda_idt = self.opt.lambda_identity
    lambda_A = self.opt.lambda_A
    lambda_B = self.opt.lambda_B
    # Identity loss
    if lambda_idt > 0:
        # G_A should be the identity if real_B is fed.
        self.idt_A = self.netG_A(self.real_B)
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        # G_B should be the identity if real_A is fed.
        if not self.opt.try_a and self.opt.no_identity_b:
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            # Identity loss for B removed: we don't expect the noise generator
            # to reproduce its noisy input unchanged.
            self.idt_B = self.real_A
            self.loss_idt_B = 0
    else:
        self.loss_idt_A = 0
        self.loss_idt_B = 0
    if self.opt.l0_reg:
        # Push grayscale values of fake_B toward 0 or 1 (L0-style regularizer).
        image = -1 * torch.clamp(self.rgb2gray(self.fake_B), -0.5, 0.5) + 0.5
        mask_toward_zero = image.clone()
        mask_toward_one = image.clone()
        mask_toward_zero[mask_toward_zero > 0.5] = 0
        mask_toward_one[mask_toward_one < 0.5] = 1
        self.loss_L0_reg = 0.0001 * (torch.sum(mask_toward_zero) + torch.sum(1 - mask_toward_one))  # TODO: expose weight as a parameter
    else:
        self.loss_L0_reg = 0
    # Multi-scale gradient losses (these are the problematic terms, see below)
    self.loss_scale_G_A = self.opt.lambda_scale_G_A * self.calc_scale_loss(self.real_A, self.fake_B)
    self.loss_scale_G_B = self.opt.lambda_scale_G_B * self.calc_scale_loss(self.real_B, self.fake_A)
    # GAN loss D_A(G_A(A))
    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) * self.opt.lambda_G_A + self.criterionGAN(self.netD_A(self.idt_A), True)
    # GAN loss D_B(G_B(B))
    # (earlier variants zeroed loss_G_B and the forward cycle loss when try_a was set)
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
    # Forward cycle loss (disabled)
    self.loss_cycle_A = 0
    # Backward cycle loss
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
    # Combined loss
    self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B
                   + self.loss_idt_A + self.loss_idt_B + self.loss_L0_reg
                   + self.loss_scale_G_A + self.loss_scale_G_B)
    self.loss_G.backward()
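One thing worth flagging in the l0_reg branch above: it mutates cloned tensors in place with masked assignment. If autograd ever complains about in-place operations there, an equivalent formulation with torch.where avoids them entirely. A minimal sketch (the standalone function l0_reg and its signature are for illustration only, not part of the code above):

import torch

def l0_reg(gray, weight=0.0001):
    # Same computation as the l0_reg branch above, without in-place masking.
    image = -1 * torch.clamp(gray, -0.5, 0.5) + 0.5
    mask_toward_zero = torch.where(image > 0.5, torch.zeros_like(image), image)
    mask_toward_one = torch.where(image < 0.5, torch.ones_like(image), image)
    return weight * (mask_toward_zero.sum() + (1 - mask_toward_one).sum())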
In this function I use loss_scale_G_A and loss_scale_G_B (see the combined loss near the end), which are computed by the function below; these two losses are what cause the problem.
def calc_scale_loss(self, real, fake):
    list_of_scale = [1, 2, 4, 8, 16]
    scale_factor = [0.0001, 0.001, 0.01, 0.1, 1]
    _, _, orig_h, orig_w = real.shape  # tensors are (N, C, H, W)
    loss_scale = 0
    for index, scale in enumerate(list_of_scale):
        scaled_h = int(orig_h / scale)
        scaled_w = int(orig_w / scale)
        # Downscale the grayscale images to the current scale.
        scaled_real = F.adaptive_avg_pool3d(self.rgb2gray(real), (1, scaled_h, scaled_w))
        scaled_fake = F.adaptive_avg_pool3d(self.rgb2gray(fake), (1, scaled_h, scaled_w))
        # Image gradients via a fixed convolution filter.
        grad_scaled_real = F.conv2d(scaled_real, self.grad_conv_filter, padding=1)  # TODO: padding
        grad_scaled_fake = F.conv2d(scaled_fake, self.grad_conv_filter, padding=1)  # TODO: padding
        # (an earlier commented-out variant weighted both gradient maps by a brightness mask)
        curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake, grad_scaled_real)
        loss_scale += curr_loss  # TODO: overall factor (best so far is 10)
    return loss_scale
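A side note on the pooling: assuming rgb2gray returns an (N, 1, H, W) tensor, the 3D pooling call above only works because the dimensions happen to line up, and the same downscaling can be written with the 2D API directly. A minimal sketch under that assumption:

import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 64, 64)  # stand-in for self.rgb2gray(real)
for scale in [1, 2, 4, 8, 16]:
    pooled = F.adaptive_avg_pool2d(x, (64 // scale, 64 // scale))
    print(pooled.shape)  # torch.Size([2, 1, 64 // scale, 64 // scale])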
If I remove these two losses from the total loss (i.e., from loss_G), I am able to train the model.
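For reference, autograd's anomaly-detection mode makes backward() report the forward operation behind a failing gradient, which can help localize which of the two terms breaks. A minimal, self-contained sketch (the toy loss is just a placeholder for loss_scale_G_A + loss_scale_G_B):

import torch

torch.autograd.set_detect_anomaly(True)  # backward() errors will name the forward op responsible

x = torch.randn(1, 1, 8, 8, requires_grad=True)
loss = torch.clamp(x, -0.5, 0.5).sum()  # placeholder for the scale losses
loss.backward()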