GPU out of memory: memory usage increases after each iteration

Hi,

I have an issue where GPU memory usage increases after each training iteration and eventually leads to an out-of-memory (OOM) error. Here are some snippets from the training code:

    def gen_update(self, x_s, x_t, hyperparameters):
        """Run one generator optimization step for the attribute/identity
        disentanglement model.

        Encodes both images, decodes within-person and cross-person
        (swapped) combinations, re-encodes the swapped outputs for
        identity/attribute/cycle consistency, sums the weighted losses,
        and takes one optimizer step on the generator.

        Args:
            x_s: source image batch; the fixed ``[:, :, 19:237, 19:237]``
                crop below implies N x C x H x W with H, W >= 237
                (NOTE(review): assumed layout -- confirm against the loader).
            x_t: target image batch, same layout as ``x_s``.
            hyperparameters: dict of loss weights ('gan_w', 'recon_x_w',
                'recon_identity_w', 'recon_attr_w', 'recon_cyc_w') plus the
                GAN loss variant under ``hyperparameters['dis']['gan_type']``.
        """
        self.opt_G.zero_grad()

        #encoder: attribute codes of the real inputs
        attr_s = self.G.encoder(x_s)
        attr_t = self.G.encoder(x_t)

        # Identity embeddings of the *real* inputs. no_grad is appropriate
        # here: no gradient needs to flow into the identity encoder from the
        # real images, and it also keeps these activations out of the graph.
        # The crop + bilinear resize to 112x112 presumably matches the
        # identity network's expected input size -- TODO confirm.
        with torch.no_grad():
            id_s, x_s_feats = self.IdEncoder(F.interpolate(x_s[:, :, 19:237, 19:237], [112, 112], mode='bilinear', align_corners=True))
            id_t, x_t_feats = self.IdEncoder(F.interpolate(x_t[:, :, 19:237, 19:237], [112, 112], mode='bilinear', align_corners=True))

        #decode (within person): self-reconstructions
        x_s_recon = self.G.decoder(attr_s, id_s)
        x_t_recon = self.G.decoder(attr_t, id_t)

        #decode  (cross person): swapped identity/attribute combinations
        x_st = self.G.decoder(attr_s, id_t)
        x_ts = self.G.decoder(attr_t, id_s)

        #encoder again: re-encode the swaps for consistency losses.
        # These IdEncoder calls are intentionally *not* under no_grad so the
        # identity losses can backpropagate through the generated images.
        attr_s_recon = self.G.encoder(x_st)

        id_t_recon, x_st_feats = self.IdEncoder(F.interpolate(x_st[:, :, 19:237, 19:237], [112, 112], mode='bilinear', align_corners=True))

        attr_t_recon = self.G.encoder(x_ts)

        id_s_recon, x_ts_feats = self.IdEncoder(F.interpolate(x_ts[:, :, 19:237, 19:237], [112, 112], mode='bilinear', align_corners=True))

        # NOTE(review): x_s_feats / x_t_feats / x_st_feats / x_ts_feats are
        # never used in this function -- presumably consumed elsewhere or
        # left over; verify they are intentional.

        #decode again: cycle reconstructions (swap back)
        x_sts = self.G.decoder(attr_s_recon, id_s)
        x_tst = self.G.decoder(attr_t_recon, id_t)

        #reconstruction loss
        loss_gen_recon_x_s = self.recong_criterion(x_s_recon, x_s)
        loss_gen_recon_x_t = self.recong_criterion(x_t_recon, x_t)
        # id_s / id_t were computed under no_grad, so these cosine losses only
        # push gradients through the id_*_recon branch, as intended.
        loss_gen_recon_id_s = self.cosine_criterion(id_s_recon, id_s)
        loss_gen_recon_id_t = self.cosine_criterion(id_t_recon, id_t)
        loss_gen_recon_attr_s = self.attr_criterion(attr_s_recon, attr_s)
        loss_gen_recon_attr_t = self.attr_criterion(attr_t_recon, attr_t)
        loss_gen_cycrecon_x_s = self.recong_criterion(x_sts, x_s)
        loss_gen_cycrecon_x_t = self.recong_criterion(x_tst, x_t)

        #adv loss: generator side of the GAN loss on all generated images
        loss_adv_x_s_recon = self.dis_real.calc_gen_loss(x_s_recon, hyperparameters['dis']['gan_type'])
        loss_adv_x_t_recon = self.dis_real.calc_gen_loss(x_t_recon, hyperparameters['dis']['gan_type'])
        loss_adv_x_st = self.dis_real.calc_gen_loss(x_st, hyperparameters['dis']['gan_type'])
        loss_adv_x_ts = self.dis_real.calc_gen_loss(x_ts, hyperparameters['dis']['gan_type'])

        # Weighted total. NOTE(review): if any of these loss tensors (or
        # loss_gen_total itself) is stored elsewhere for logging/plotting
        # without .item() or .detach(), each stored tensor keeps the entire
        # computation graph alive -- a common cause of per-iteration GPU
        # memory growth like the one reported here.
        loss_gen_total =       hyperparameters['gan_w'] *  (loss_adv_x_s_recon + loss_adv_x_st \
                               + loss_adv_x_t_recon +loss_adv_x_ts) + \
                               hyperparameters['recon_x_w'] * (loss_gen_recon_x_s + loss_gen_recon_x_t) + \
                               hyperparameters['recon_identity_w'] *(loss_gen_recon_id_s + loss_gen_recon_id_t) + \
                               hyperparameters['recon_attr_w'] * (loss_gen_recon_attr_s +loss_gen_recon_attr_t) + \
                               hyperparameters['recon_cyc_w'] * (loss_gen_cycrecon_x_s +loss_gen_cycrecon_x_t)

        loss_gen_total.backward()
        self.opt_G.step()

Hi,
Do you store any loss values for plotting, logging, etc.? If so, make sure you store them via `.item()` or `.detach()` — keeping a loss tensor that is still attached to the computation graph keeps the entire graph alive and grows memory each iteration. Could you also show us more of your training loop and the discriminator updates?