Hi all
I am working with the DRIT model code. I copied it from GitHub and am trying to remove some unnecessary loss functions from the model part.
The modified model.py is below:
import networks
import torch
import torch.nn as nn
class DRIT(nn.Module):
    def __init__(self, opts):
        """Build the DRIT model: two domain discriminators, a content
        discriminator, content/attribute encoders, a generator, one Adam
        optimizer per sub-network, and the L1 reconstruction criterion.

        Args:
            opts: options namespace; must provide input_dim_a, input_dim_b,
                dis_norm and dis_spectral_norm (consumed by networks.Dis).
        """
        # BUG FIX: the pasted code defined `init` and called `super().init()`,
        # so this constructor never ran; it must be `__init__`.
        super(DRIT, self).__init__()

        # hyper-parameters
        lr = 0.0001
        lr_dcontent = lr / 2.5  # the content discriminator learns more slowly
        self.nz = 8  # dimensionality of the attribute (latent) code

        # discriminators
        self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disContent = networks.Dis_content()

        # encoders
        self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)  # content encoder
        self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)  # attribute encoder

        # generator
        self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)

        # optimizers (one per sub-network)
        self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disContent_opt = torch.optim.Adam(self.disContent.parameters(), lr=lr_dcontent, betas=(0.5, 0.999), weight_decay=0.0001)
        self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

        # loss function for the reconstruction terms
        self.criterionL1 = torch.nn.L1Loss()
def initialize(self):
    """Apply Gaussian weight initialization to every sub-network."""
    nets = (self.disA, self.disB, self.disContent,
            self.gen, self.enc_c, self.enc_a)
    for net in nets:
        net.apply(networks.gaussian_weights_init)
def set_scheduler(self, opts, last_ep=0):
    """Attach a learning-rate scheduler to every optimizer.

    Args:
        opts: scheduler options forwarded to networks.get_scheduler.
        last_ep: last finished epoch, so a resumed run continues the decay.
    """
    for name in ('disA', 'disB', 'disContent', 'enc_c', 'enc_a', 'gen'):
        optimizer = getattr(self, name + '_opt')
        setattr(self, name + '_sch',
                networks.get_scheduler(optimizer, opts, last_ep))
def setgpu(self, gpu):
    """Remember the target GPU id and move every sub-network onto it."""
    self.gpu = gpu
    for net in (self.disA, self.disB, self.disContent,
                self.enc_c, self.enc_a, self.gen):
        net.cuda(gpu)
def test_forward(self, image, a2b=True):
    """Translate one image batch to the other domain using a random
    attribute code.

    Args:
        image: input batch from domain A (a2b=True) or domain B.
        a2b: translation direction; A -> B when True, B -> A otherwise.
    Returns:
        The translated image batch.
    """
    # NOTE(review): get_z_random is not defined in this trimmed class --
    # it belongs to the original DRIT model; restore it or this call
    # raises AttributeError.
    # BUG FIX: the pasted code used curly quotes around gauss, which is a
    # Python syntax error.
    self.z_random = self.get_z_random(image.size(0), self.nz, 'gauss')
    if a2b:
        self.z_content = self.enc_c.forward_a(image)
        output = self.gen.forward_b(self.z_content, self.z_random)
    else:
        self.z_content = self.enc_c.forward_b(image)
        output = self.gen.forward_a(self.z_content, self.z_random)
    return output
def test_forward_transfer(self, image_a, image_b, a2b=True):
    """Translate using one image's content and the OTHER image's attribute
    code (attribute transfer between two real images).

    Args:
        image_a: image batch from domain A.
        image_b: image batch from domain B.
        a2b: if True return gen_b(content_A, attr_B), else
            gen_a(content_B, attr_A).
    Returns:
        The transferred image batch.
    """
    self.z_content_a, self.z_content_b = self.enc_c.forward(image_a, image_b)
    # NOTE(review): self.concat is never set in the trimmed __init__; it
    # selected the VAE-style attribute encoder in the original DRIT.
    # Set it there or this line raises AttributeError.
    if self.concat:
        # reparameterization trick: z = mu + eps * std
        self.mu_a, self.logvar_a, self.mu_b, self.logvar_b = self.enc_a.forward(image_a, image_b)
        std_a = self.logvar_a.mul(0.5).exp_()
        # BUG FIX: curly quotes around gauss (syntax error) replaced with
        # straight quotes; get_z_random must also be restored (see
        # test_forward).
        eps = self.get_z_random(std_a.size(0), std_a.size(1), 'gauss')
        self.z_attr_a = eps.mul(std_a).add_(self.mu_a)
        std_b = self.logvar_b.mul(0.5).exp_()
        eps = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss')
        self.z_attr_b = eps.mul(std_b).add_(self.mu_b)
    else:
        self.z_attr_a, self.z_attr_b = self.enc_a.forward(image_a, image_b)
    if a2b:
        output = self.gen.forward_b(self.z_content_a, self.z_attr_b)
    else:
        output = self.gen.forward_a(self.z_content_b, self.z_attr_a)
    return output
def forward(self):
    """One training forward pass: encode both inputs, perform the first
    cross-translation plus within-domain self-reconstruction, then
    re-encode the fakes and perform the second cross-translation
    (cycle reconstruction).

    Reads self.input_A / self.input_B; stores every intermediate tensor
    on self for the loss computations in update_D() / update_EG().
    """
    # input images
    half_size = 1  # take one image only
    real_A = self.input_A
    real_B = self.input_B
    self.real_A_encoded = real_A[0:half_size]  # first image of batch A
    self.real_B_encoded = real_B[0:half_size]  # first image of batch B
    # encode into the shared content space and the attribute space
    self.z_content_a, self.z_content_b = self.enc_c.forward(self.real_A_encoded, self.real_B_encoded)
    self.z_attr_a, self.z_attr_b = self.enc_a.forward(self.real_A_encoded, self.real_B_encoded)
    # first cross translation: each generator input batch is
    # [cross-domain content, within-domain content] paired with the
    # target domain's own attribute code
    input_content_forA = torch.cat((self.z_content_b, self.z_content_a),0)
    input_content_forB = torch.cat((self.z_content_a, self.z_content_b),0)
    input_attr_forA = torch.cat((self.z_attr_a, self.z_attr_a),0)
    input_attr_forB = torch.cat((self.z_attr_b, self.z_attr_b),0)
    output_fakeA = self.gen.forward_a(input_content_forA, input_attr_forA)
    output_fakeB = self.gen.forward_b(input_content_forB, input_attr_forB)
    # split: fake_A_encoded = B-content + A-attr (translation),
    #        fake_AA_encoded = A-content + A-attr (self-reconstruction)
    self.fake_A_encoded, self.fake_AA_encoded = torch.split(output_fakeA, self.z_content_a.size(0), dim=0)
    self.fake_B_encoded, self.fake_BB_encoded = torch.split(output_fakeB, self.z_content_a.size(0), dim=0)
    # second cross translation
    # re-encode the translated fakes; note the swapped names: fake_A
    # carries B's content, so enc_c returns (recon_b, recon_a) here
    self.z_content_recon_b, self.z_content_recon_a = self.enc_c.forward(self.fake_A_encoded, self.fake_B_encoded)
    self.z_attr_recon_a, self.z_attr_recon_b = self.enc_a.forward(self.fake_A_encoded, self.fake_B_encoded)
    # second cross translation: cycle-reconstruct each original domain
    self.fake_A_recon = self.gen.forward_a(self.z_content_recon_a, self.z_attr_recon_a)
    self.fake_B_recon = self.gen.forward_b(self.z_content_recon_b, self.z_attr_recon_b)
def forward_content(self):
    """Encode one image per domain into the shared content space.

    Reads self.input_A / self.input_B and stores real_*_encoded plus
    z_content_a / z_content_b on self.
    """
    self.real_A_encoded = self.input_A[:1]
    self.real_B_encoded = self.input_B[:1]
    self.z_content_a, self.z_content_b = self.enc_c.forward(
        self.real_A_encoded, self.real_B_encoded)
def update_D_content(self, image_a, image_b):
    """One optimization step for the content discriminator only.

    Args:
        image_a: batch from domain A (already on the GPU).
        image_b: batch from domain B (already on the GPU).
    """
    self.input_A = image_a
    self.input_B = image_b
    self.forward_content()
    self.disContent_opt.zero_grad()
    # backward_contentD calls loss.backward() internally, so gradients
    # are already populated when it returns
    loss_D_Content = self.backward_contentD(self.z_content_a, self.z_content_b)
    self.disContent_loss = loss_D_Content.item()
    # clip gradients to stabilize the content-adversarial game
    nn.utils.clip_grad_norm_(self.disContent.parameters(), 5)
    self.disContent_opt.step()
def update_D(self, image_a, image_b):
    """One optimization step for disA, disB and (a second time per cycle)
    disContent. Also runs self.forward(); its intermediate tensors are
    reused afterwards by update_EG().

    Args:
        image_a: batch from domain A (already on the GPU).
        image_b: batch from domain B (already on the GPU).
    """
    self.input_A = image_a
    self.input_B = image_b
    self.forward()
    # update disA
    self.disA_opt.zero_grad()
    loss_D1_A = self.backward_D(self.disA, self.real_A_encoded, self.fake_A_encoded)
    self.disA_loss = loss_D1_A.item()
    self.disA_opt.step()
    # update disB
    self.disB_opt.zero_grad()
    loss_D1_B = self.backward_D(self.disB, self.real_B_encoded, self.fake_B_encoded)
    self.disB_loss = loss_D1_B.item()
    self.disB_opt.step()
    # update disContent a second time (it is also trained on most
    # iterations via update_D_content)
    self.disContent_opt.zero_grad()
    loss_D_Content = self.backward_contentD(self.z_content_a, self.z_content_b)
    self.disContent_loss = loss_D_Content.item()
    nn.utils.clip_grad_norm_(self.disContent.parameters(), 5)
    self.disContent_opt.step()
    # NOTE(review): update_EG() later backpropagates through the graph
    # built by self.forward() above. If any in-place op touches a tensor
    # of that graph in the meantime (e.g. ReLU/LeakyReLU(inplace=True) or
    # `out += residual` inside networks.py), loss_G.backward() raises the
    # "modified by an inplace operation" RuntimeError -- TODO confirm
    # against networks.py.
def backward_D(self, netD, real, fake):
    """Discriminator BCE loss on one (real, fake) pair; backpropagates it.

    The fake batch is detached so no gradient flows into the generator.

    Args:
        netD: the discriminator to train (disA or disB).
        real: real image batch for netD's domain.
        fake: generated image batch for netD's domain.
    Returns:
        The total loss tensor (its backward() has already been called).
    """
    pred_fake = netD.forward(fake.detach())
    pred_real = netD.forward(real)
    loss_D = 0
    # netD may return several outputs (e.g. multi-scale); sum the BCE
    # over all of them. (Fixed: the original bound an unused `enumerate`
    # index.)
    for out_fake_logits, out_real_logits in zip(pred_fake, pred_real):
        out_fake = torch.sigmoid(out_fake_logits)
        out_real = torch.sigmoid(out_real_logits)
        all0 = torch.zeros_like(out_fake).cuda(self.gpu)  # fake target
        all1 = torch.ones_like(out_real).cuda(self.gpu)   # real target
        ad_fake_loss = nn.functional.binary_cross_entropy(out_fake, all0)
        ad_true_loss = nn.functional.binary_cross_entropy(out_real, all1)
        loss_D = loss_D + (ad_true_loss + ad_fake_loss)
    loss_D.backward()
    return loss_D
def backward_contentD(self, imageA, imageB):
    """Content-discriminator loss: domain-A content features labeled 0,
    domain-B features labeled 1. Both inputs are detached so only
    disContent receives gradients.

    BUG FIX: the original overwrote loss_D on every loop iteration and
    backpropagated only the last output; this version accumulates over
    all outputs, matching backward_D. Behavior is identical when
    disContent returns a single output (the DRIT default).

    Returns:
        The total loss tensor (its backward() has already been called).
    """
    pred_fake = self.disContent.forward(imageA.detach())
    pred_real = self.disContent.forward(imageB.detach())
    loss_D = 0
    for out_a, out_b in zip(pred_fake, pred_real):
        out_fake = torch.sigmoid(out_a)
        out_real = torch.sigmoid(out_b)
        # assumes each output is one value per sample (1-D) -- TODO
        # confirm against networks.Dis_content
        all1 = torch.ones((out_real.size(0))).cuda(self.gpu)
        all0 = torch.zeros((out_fake.size(0))).cuda(self.gpu)
        ad_true_loss = nn.functional.binary_cross_entropy(out_real, all1)
        ad_fake_loss = nn.functional.binary_cross_entropy(out_fake, all0)
        loss_D = loss_D + ad_true_loss + ad_fake_loss
    loss_D.backward()
    return loss_D
def update_EG(self):
    """One optimization step for the generator and both encoders, using
    the tensors computed by self.forward() during update_D().
    """
    self.enc_c_opt.zero_grad()
    self.enc_a_opt.zero_grad()
    self.gen_opt.zero_grad()

    # content adversarial loss: make disContent unable to tell domains apart
    loss_G_GAN_Acontent = self.backward_G_GAN_content(self.z_content_a)
    loss_G_GAN_Bcontent = self.backward_G_GAN_content(self.z_content_b)

    # domain adversarial loss: make disA/disB classify the fakes as real
    loss_G_GAN_A = self.backward_G_GAN(self.fake_A_encoded, self.disA)
    loss_G_GAN_B = self.backward_G_GAN(self.fake_B_encoded, self.disB)

    # cross-cycle consistency and self-reconstruction L1 losses (x10)
    loss_G_L1_A = self.criterionL1(self.fake_A_recon, self.real_A_encoded) * 10
    loss_G_L1_B = self.criterionL1(self.fake_B_recon, self.real_B_encoded) * 10
    loss_G_L1_AA = self.criterionL1(self.fake_AA_encoded, self.real_A_encoded) * 10
    loss_G_L1_BB = self.criterionL1(self.fake_BB_encoded, self.real_B_encoded) * 10

    loss_G = (loss_G_GAN_A + loss_G_GAN_B
              + loss_G_GAN_Acontent + loss_G_GAN_Bcontent
              + loss_G_L1_AA + loss_G_L1_BB
              + loss_G_L1_A + loss_G_L1_B)

    # FIX: the graph is never reused after this call, so the original
    # retain_graph=True only wasted memory and was dropped.
    # NOTE(review): this backward traverses the graph built by
    # self.forward() inside update_D(). If you hit "one of the variables
    # needed for gradient computation has been modified by an inplace
    # operation" (ReluBackward0 at version 1), the usual culprits are
    # in-place ops inside networks.py (ReLU/LeakyReLU(inplace=True) or
    # `out += residual` in the residual blocks) -- replace them with
    # out-of-place variants, or call self.forward() again at the top of
    # this method so the graph is rebuilt after the discriminator steps.
    # TODO confirm against networks.py.
    loss_G.backward()

    self.enc_c_opt.step()
    self.enc_a_opt.step()
    self.gen_opt.step()

    # bookkeeping for logging
    self.gan_loss_a = loss_G_GAN_A.item()
    self.gan_loss_b = loss_G_GAN_B.item()
    self.gan_loss_acontent = loss_G_GAN_Acontent.item()
    self.gan_loss_bcontent = loss_G_GAN_Bcontent.item()
    self.l1_recon_A_loss = loss_G_L1_A.item()
    self.l1_recon_B_loss = loss_G_L1_B.item()
    self.l1_recon_AA_loss = loss_G_L1_AA.item()
    self.l1_recon_BB_loss = loss_G_L1_BB.item()
    self.G_loss = loss_G.item()
def backward_G_GAN_content(self, data):
    """Encoder adversarial loss on content features: push disContent's
    prediction toward 0.5 (content indistinguishable between domains).

    BUG FIX: the original returned only the LAST output's loss when
    disContent produced several outputs; this accumulates over all of
    them, matching backward_G_GAN. Behavior is identical for the
    single-output DRIT default.

    Returns:
        The accumulated adversarial loss (not yet backpropagated).
    """
    outs = self.disContent.forward(data)
    ad_loss = 0
    for out in outs:
        outputs_fake = torch.sigmoid(out)
        # target 0.5: maximally uncertain discriminator
        all_half = 0.5 * torch.ones((outputs_fake.size(0))).cuda(self.gpu)
        ad_loss = ad_loss + nn.functional.binary_cross_entropy(outputs_fake, all_half)
    return ad_loss
def backward_G_GAN(self, fake, netD=None):
    """Generator adversarial loss: make netD classify `fake` as real (1).

    Returns:
        The accumulated BCE loss over all of netD's outputs (not yet
        backpropagated).
    """
    loss_G = 0
    for logits in netD.forward(fake):
        prob = torch.sigmoid(logits)
        target_real = torch.ones_like(prob).cuda(self.gpu)
        loss_G = loss_G + nn.functional.binary_cross_entropy(prob, target_real)
    return loss_G
def update_lr(self):
    """Advance every learning-rate scheduler by one epoch."""
    schedulers = (self.disA_sch, self.disB_sch, self.disContent_sch,
                  self.enc_c_sch, self.enc_a_sch, self.gen_sch)
    for sch in schedulers:
        sch.step()
def save(self, filename, ep, total_it):
    """Checkpoint all network and optimizer states plus training progress.

    BUG FIXES: the pasted version used curly quotes for several keys
    (a syntax error) and omitted the disB optimizer state, so disB's
    Adam moments were lost on resume. 'disB_opt' is a new checkpoint
    key -- make sure the matching resume() loads it too.

    Args:
        filename: path for torch.save.
        ep: current epoch number.
        total_it: total iteration counter.
    """
    state = {
        'disA': self.disA.state_dict(),
        'disB': self.disB.state_dict(),
        'disContent': self.disContent.state_dict(),
        'enc_c': self.enc_c.state_dict(),
        'enc_a': self.enc_a.state_dict(),
        'gen': self.gen.state_dict(),
        'disA_opt': self.disA_opt.state_dict(),
        'disB_opt': self.disB_opt.state_dict(),  # was missing
        'disContent_opt': self.disContent_opt.state_dict(),
        'enc_c_opt': self.enc_c_opt.state_dict(),
        'enc_a_opt': self.enc_a_opt.state_dict(),
        'gen_opt': self.gen_opt.state_dict(),
        'ep': ep,
        'total_it': total_it,
    }
    torch.save(state, filename)
    return
def assemble_outputs(self):
    """Build a 2x4 image grid for visualization.

    Row 1: real A | fake B (A content) | self-recon AA | cycle-recon A
    Row 2: real B | fake A (B content) | self-recon BB | cycle-recon B

    NOTE(review): normalize_image is not defined in this trimmed class --
    it belongs to the original DRIT model; restore it or this fails.
    CONSISTENCY FIX: the pasted version skipped normalize_image for
    real_A and fake_A only; all eight images are normalized here so the
    grid displays uniformly.
    """
    images_a = self.normalize_image(self.real_A_encoded).detach()
    images_a1 = self.normalize_image(self.fake_A_encoded).detach()
    images_a2 = self.normalize_image(self.fake_AA_encoded).detach()
    images_a3 = self.normalize_image(self.fake_A_recon).detach()
    images_b = self.normalize_image(self.real_B_encoded).detach()
    images_b1 = self.normalize_image(self.fake_B_encoded).detach()
    images_b2 = self.normalize_image(self.fake_BB_encoded).detach()
    images_b3 = self.normalize_image(self.fake_B_recon).detach()
    row1 = torch.cat((images_a[0:1, ::], images_b1[0:1, ::], images_a2[0:1, ::], images_a3[0:1, ::]), 3)
    row2 = torch.cat((images_b[0:1, ::], images_a1[0:1, ::], images_b2[0:1, ::], images_b3[0:1, ::]), 3)
    return torch.cat((row1, row2), 2)
And the training loop is:
# Training loop.
# NOTE(review): this fragment assumes `gc`, `tqdm`, `model`, `opts`,
# `train_loader`, `saver`, `ep0` and `total_it` are defined earlier in
# the script. The pasted version used curly quotes (syntax errors) and
# had its indentation stripped; both are restored here.
print('\n— train —')
for ep in range(ep0, opts.n_ep):
    gc.collect()
    torch.cuda.empty_cache()
    for it, (images_a, images_b) in enumerate(
            tqdm(train_loader, desc='Epoch: {}/{}'.format(ep, opts.n_ep))):
        # move the input batches to the GPU, outside the autograd graph
        images_a = images_a.cuda(opts.gpu).detach()
        images_b = images_b.cuda(opts.gpu).detach()
        # train only the content discriminator on most iterations; every
        # opts.d_iter-th iteration update the full model
        if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
            model.update_D_content(images_a, images_b)
            continue
        else:
            model.update_D(images_a, images_b)
            model.update_EG()
        print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
        total_it += 1
    model.update_lr()
    # save result image
    saver.write_img(ep, model)
    # save network weights every 50 epochs
    if ep % 50 == 0:
        saver.write_model(ep, total_it, model)
    print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
After running the training part, the runtime error below occurs:
RuntimeError Traceback (most recent call last)
Input In [6], in <cell line: 4>()
16 else:
17 model.update_D(images_a, images_b)
—> 18 model.update_EG()
20 # print(‘total_it: %d (ep %d, it %d), lr %08f’ % (total_it, ep, it, model.gen_opt.param_groups[0][‘lr’]))
21 total_it += 1
File ~\Project_All\DRIT-master\model.py:237, in DRIT.update_EG(self)
233 loss_G = loss_G_GAN_A + loss_G_GAN_B + loss_G_GAN_Acontent + loss_G_GAN_Bcontent + loss_G_L1_AA + loss_G_L1_BB + loss_G_L1_A + loss_G_L1_B
236 #do backward()
→ 237 loss_G.backward(retain_graph=True)
238 #self.backward_EG()
239
240
241 # do optimisation
242 self.enc_c_opt.step()
File ~\anaconda3\envs\Deeplearning\lib\site-packages\torch_tensor.py:363, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
354 if has_torch_function_unary(self):
355 return handle_torch_function(
356 Tensor.backward,
357 (self,),
(…)
361 create_graph=create_graph,
362 inputs=inputs)
→ 363 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File ~\anaconda3\envs\Deeplearning\lib\site-packages\torch\autograd_init_.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
168 retain_graph = create_graph
170 # The reason we repeat same the comment below is that
171 # some Python versions print out the first line of a multi-line function
172 # calls in the traceback and some print out the last line
→ 173 Variable.execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
174 tensors, grad_tensors, retain_graph, create_graph, inputs,
175 allow_unreachable=True, accumulate_grad=True)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 256, 64, 64]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I deleted the in-place operations in the model part, and I also changed `inplace=False` in the ReLU operations, but I still get this error!
How can i solve this issue?