Hi,
I have an issue where GPU memory usage grows after every iteration and eventually causes an out-of-memory (OOM) error. Here are the relevant snippets from the training code:
def gen_update(self, x_s, x_t, hyperparameters):
    """Run one generator update step on a source/target image pair.

    Performs within-person reconstruction, cross-person swapping, and a
    cycle reconstruction, then backpropagates the weighted sum of the
    reconstruction, identity, attribute, cycle, and adversarial losses
    and steps the generator optimizer.

    Args:
        x_s: source-identity image batch (N, C, H, W); H and W must cover
            the fixed [19:237] face crop, so presumably 256x256 inputs --
            TODO confirm against the data loader.
        x_t: target-identity image batch, same layout as ``x_s``.
        hyperparameters: dict with loss weights 'gan_w', 'recon_x_w',
            'recon_identity_w', 'recon_attr_w', 'recon_cyc_w' and a
            nested 'dis' dict providing 'gan_type'.

    Returns:
        None. Side effect: one optimizer step on ``self.opt_G``.

    NOTE(review): nothing inside this function retains the graph across
    iterations -- every loss is a local that goes out of scope after
    ``step()``. If GPU memory still grows each iteration, the usual
    culprit is code *outside* this function keeping a reference to a
    loss/output tensor (e.g. accumulating ``loss_gen_total`` for logging
    without calling ``.item()`` / ``.detach()``) -- check the caller.
    """
    self.opt_G.zero_grad()

    # Encode attributes with the trainable generator encoder.
    attr_s = self.G.encoder(x_s)
    attr_t = self.G.encoder(x_t)

    # Identity embeddings of the *real* images come from the (frozen)
    # identity encoder. no_grad here is correct: it keeps the autograd
    # graph from growing through the identity network for real inputs.
    with torch.no_grad():
        id_s, x_s_feats = self.IdEncoder(
            F.interpolate(x_s[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear', align_corners=True))
        id_t, x_t_feats = self.IdEncoder(
            F.interpolate(x_t[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear', align_corners=True))

    # Decode (within person): reconstruct each image from its own
    # attribute code and identity embedding.
    x_s_recon = self.G.decoder(attr_s, id_s)
    x_t_recon = self.G.decoder(attr_t, id_t)

    # Decode (cross person): swap identities between the two images.
    x_st = self.G.decoder(attr_s, id_t)
    x_ts = self.G.decoder(attr_t, id_s)

    # Re-encode the swapped images. These IdEncoder calls are
    # intentionally NOT under no_grad: gradients must flow through the
    # identity network back into the generator for the identity losses.
    attr_s_recon = self.G.encoder(x_st)
    id_t_recon, x_st_feats = self.IdEncoder(
        F.interpolate(x_st[:, :, 19:237, 19:237], [112, 112],
                      mode='bilinear', align_corners=True))
    attr_t_recon = self.G.encoder(x_ts)
    id_s_recon, x_ts_feats = self.IdEncoder(
        F.interpolate(x_ts[:, :, 19:237, 19:237], [112, 112],
                      mode='bilinear', align_corners=True))

    # Decode again to close the cycle (s -> st -> sts, t -> ts -> tst).
    x_sts = self.G.decoder(attr_s_recon, id_s)
    x_tst = self.G.decoder(attr_t_recon, id_t)

    # Reconstruction losses.
    loss_gen_recon_x_s = self.recong_criterion(x_s_recon, x_s)
    loss_gen_recon_x_t = self.recong_criterion(x_t_recon, x_t)
    loss_gen_recon_id_s = self.cosine_criterion(id_s_recon, id_s)
    loss_gen_recon_id_t = self.cosine_criterion(id_t_recon, id_t)
    loss_gen_recon_attr_s = self.attr_criterion(attr_s_recon, attr_s)
    loss_gen_recon_attr_t = self.attr_criterion(attr_t_recon, attr_t)
    loss_gen_cycrecon_x_s = self.recong_criterion(x_sts, x_s)
    loss_gen_cycrecon_x_t = self.recong_criterion(x_tst, x_t)

    # Adversarial losses (generator side) on all generated images.
    loss_adv_x_s_recon = self.dis_real.calc_gen_loss(
        x_s_recon, hyperparameters['dis']['gan_type'])
    loss_adv_x_t_recon = self.dis_real.calc_gen_loss(
        x_t_recon, hyperparameters['dis']['gan_type'])
    loss_adv_x_st = self.dis_real.calc_gen_loss(
        x_st, hyperparameters['dis']['gan_type'])
    loss_adv_x_ts = self.dis_real.calc_gen_loss(
        x_ts, hyperparameters['dis']['gan_type'])

    # Weighted total; all weights come from the hyperparameter dict.
    loss_gen_total = (
        hyperparameters['gan_w'] * (loss_adv_x_s_recon + loss_adv_x_st
                                    + loss_adv_x_t_recon + loss_adv_x_ts)
        + hyperparameters['recon_x_w'] * (loss_gen_recon_x_s
                                          + loss_gen_recon_x_t)
        + hyperparameters['recon_identity_w'] * (loss_gen_recon_id_s
                                                 + loss_gen_recon_id_t)
        + hyperparameters['recon_attr_w'] * (loss_gen_recon_attr_s
                                             + loss_gen_recon_attr_t)
        + hyperparameters['recon_cyc_w'] * (loss_gen_cycrecon_x_s
                                            + loss_gen_cycrecon_x_t)
    )

    loss_gen_total.backward()
    self.opt_G.step()