for epoch in range(args.start_epoch,args.max_epoch+1):
generator.train()
for batch_id, (img,mask,ohmask) in enumerate(trainloader):
if args.nogpu:
img,mask,ohmask = Variable(img),Variable(mask),Variable(ohmask)
else:
img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda()),Variable(ohmask.cuda())
i = len(trainloader)*(epoch-1) + batch_id
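# Global iteration index across epochs, used by the poly learning-rate schedule for both optimizers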
############################
# Semi-Supervised Training #
############################
if args.mode == 'semi':
## TODO: Extend random interleaving for split of any size
mid = args.batch_size // 2
img_1,mask_1,ohmask_1 = img[0:mid,...],mask[0:mid,...],ohmask[0:mid,...]
img_2,mask_2,ohmask_2 = img[mid:,...],mask[mid:,...],ohmask[mid:,...]
# Random Interleaving
if random.random() <0.5:
img_l,mask_l,ohmask_l = img_1,mask_1,ohmask_1
img_ul,mask_ul,ohmask_ul = img_2,mask_2,ohmask_2
else:
img_ul,mask_ul,ohmask_ul = img_1,mask_1,ohmask_1
img_l,mask_l,ohmask_l = img_2,mask_2,ohmask_2
############################################
# Labelled data for Discriminator Training #
############################################
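# volatile=True: this forward pass only produces fake inputs for the discriminator, so no history is kept for the generator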
out_img_map = generator(Variable(img_l.data,volatile=True))
out_img_map = nn.Softmax2d()(out_img_map)
N = out_img_map.size()[0]
H = out_img_map.size()[2]
W = out_img_map.size()[3]
# Generate the Real and Fake Labels
target_fake = Variable(torch.zeros((N,H,W)).long())
target_real = Variable(torch.ones((N,H,W)).long())
if not args.nogpu:
target_fake = target_fake.cuda()
target_real = target_real.cuda()
# Train on Real
conf_map_real = nn.LogSoftmax()(discriminator(ohmask_l.float()))
optimizer_D.zero_grad()
# Perform Label smoothing
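# One-sided label smoothing: mix a fraction d_label_smooth of the fake target into the loss on real samples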
if args.d_label_smooth != 0:
LD_real = (1 - args.d_label_smooth)*nn.NLLLoss2d()(conf_map_real,target_real)
LD_real += args.d_label_smooth * nn.NLLLoss2d()(conf_map_real,target_fake)
else:
LD_real = nn.NLLLoss2d()(conf_map_real,target_real)
LD_real.backward()
# Train on Fake
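# Re-wrap out_img_map.data in a fresh Variable so the discriminator update does not backpropagate into the generator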
conf_map_fake = nn.LogSoftmax()(discriminator(Variable(out_img_map.data)))
LD_fake = nn.NLLLoss2d()(conf_map_fake,target_fake)
LD_fake.backward()
# Update Discriminator weights
poly_lr_scheduler(optimizer_D, args.d_lr, i)
optimizer_D.step()
########################################
# Labelled data for Generator Training #
########################################
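# Second forward pass on the labelled half, this time keeping gradients for the generator update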
out_img_map = generator(img_l)
out_img_map_smax = nn.Softmax2d()(out_img_map)
out_img_map_lsmax = nn.LogSoftmax()(out_img_map)
conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map_smax))
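# LG_ce: supervised cross-entropy against the ground-truth mask; LG_adv: adversarial term rewarding predictions the discriminator scores as real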
LG_ce = nn.NLLLoss2d()(out_img_map_lsmax,mask_l)
LG_adv = nn.NLLLoss2d()(conf_map_fake,target_real)
#####################################
# Use unlabelled data to get L_semi #
#####################################
out_img_map = generator(img_ul)
soft_pred = nn.Softmax2d()(out_img_map)
hard_pred = torch.max(soft_pred,1)[1].squeeze(1)
conf_map = nn.Softmax2d()(discriminator(Variable(soft_pred.data,volatile=True)))
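# Build a one-hot mask (NHWC layout) keeping only pixels whose discriminator confidence exceeds args.t_semi; 21 is the number of segmentation classes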
idx = np.zeros(out_img_map.data.cpu().numpy().shape,dtype=np.uint8)
idx = idx.transpose(0, 2, 3, 1)
conf_mapn = conf_map[:,1,...].data.cpu().numpy()
hard_predn = hard_pred.data.cpu().numpy()
idx[conf_mapn > args.t_semi] = np.identity(21, dtype=idx.dtype)[hard_predn[ conf_mapn > args.t_semi]]
out_img_map_lsmax = nn.LogSoftmax()(out_img_map)
LG_semi_arr = out_img_map_lsmax.masked_select(Variable(torch.from_numpy(idx).byte().cuda()))
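# Fall back to a detached zero Variable when no pixel clears the threshold (this branch is discussed in the note below)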
LG_semi = -1*LG_semi_arr.mean() if len(LG_semi_arr.data.cpu().numpy()) != 0 else Variable(torch.zeros(1).cuda())
LG_seg = LG_ce + args.lam_adv *LG_adv + args.lam_semi*LG_semi
optimizer_G.zero_grad()
LG_seg.backward()
poly_lr_scheduler(optimizer_G, args.g_lr, i)
optimizer_G.step()
print("[{}][{}] LD: {:.4f} LD_fake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
.format(epoch,i,(LD_real + LD_fake).data[0],LD_real.data[0],LD_fake.data[0],LG_seg.data[0],LG_ce.data[0],LG_adv.data[0],LG_semi.data[0]))
I believe the problem is in this line:

LG_semi = -1*LG_semi_arr.mean() if len(LG_semi_arr.data.cpu().numpy()) != 0 else Variable(torch.zeros(1).cuda())

Whenever len(LG_semi_arr.data.cpu().numpy()) == 0, LG_semi becomes a fresh zero Variable that is not connected to the graph, so the forward pass out_img_map = generator(img_ul) keeps holding GPU memory.
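
A possible workaround I am considering (just a sketch, not verified against the full script) is to keep LG_semi attached to the graph of the unlabelled forward pass even when no pixel clears args.t_semi, so that LG_seg.backward() can free those intermediate buffers instead of leaving them alive until the next iteration:

# Sketch (assumption, not verified): a zero-valued loss that is still part of
# the graph of generator(img_ul), so backward() releases its buffers.
if len(LG_semi_arr.data.cpu().numpy()) != 0:
    LG_semi = -1 * LG_semi_arr.mean()
else:
    LG_semi = 0.0 * out_img_map_lsmax.mean()

Alternatively, explicitly deleting out_img_map, soft_pred and out_img_map_lsmax after computing LG_semi might release the memory as well.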
Please write back if you need more details.