CycleGAN Error during training - Trying to backward through the graph a second time

Hi,

I hope you are all well…

I am trying to write a CycleGAN. In the first epoch I want to use stable data (real_data_A and real_data_B); after that I would like to switch to using the outputs of the two generators.

When I do this, I get the following error:


**RuntimeError**: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
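
For context, here is a minimal sketch of what this error means in isolation (the tensors w and x and the helper backward_twice are hypothetical, not from my code). By default, backward() frees the intermediate values autograd saved during the forward pass, so a second backward() through the same graph fails unless the first call used retain_graph=True:

```python
import torch

w = torch.randn(3, requires_grad=True)  # hypothetical parameter
x = torch.randn(3)                      # hypothetical input


def backward_twice(retain: bool) -> bool:
    """Return True if a second backward() over the same graph raises RuntimeError."""
    y = torch.tanh(w * x).sum()      # tanh saves its output for the backward pass
    y.backward(retain_graph=retain)  # with retain=False the saved tensors are freed here
    try:
        y.backward()                 # second pass over the same graph
    except RuntimeError:
        return True
    return False
```

With retain=False the second call fails with the same "backward through the graph a second time" message; with retain=True it succeeds.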

For the first epoch:
Generator A:

        Gen_model_A.zero_grad()
        loss_A,out_A = loss_fn_I(Gen_model_A, real_data_A,lab_A)
        temp=loss_A.detach().cpu().numpy()
        if np.isnan(temp)==True:
            exit()
        #    print(out.element_size()*out.nelement())
        D_G_z2A = out_A.mean().item()
        loss_A.backward(retain_graph=True)
        optimizer_gen_A.step()

Generator B:

        Gen_model_B.zero_grad()
        loss_B,out_B = loss_fn(Gen_model_B, real_data_B,lab_B)
        temp=loss_B.detach().cpu().numpy()
        if np.isnan(temp)==True:
            exit()
        #    print(out.element_size()*out.nelement())
        D_G_z2B = out_B.mean().item()
        loss_B.backward(retain_graph=True)
        optimizer_gen_B.step()

Then from epoch 2 onwards:
Generator A:

        Gen_model_A.zero_grad()
        loss_A,out_A = loss_fn(Gen_model_A, out_B,lab_A)
        temp=loss_A.detach().cpu().numpy()
        if np.isnan(temp)==True:
            exit()
        #    print(out.element_size()*out.nelement())
        D_G_z2A = out_A.mean().item()
        loss_A.backward()
        optimizer_disc_A.step()

Generator B:

        Gen_model_B.zero_grad()
        loss_B,out_B = loss_fn_I(Gen_model_B, out_A,lab_B)
        temp=loss_B.detach().cpu().numpy()
        if np.isnan(temp)==True:
            exit()
        #    print(out.element_size()*out.nelement())
        D_G_z2B = out_B.mean().item()
        loss_B.backward()  # <-- this is where I get the error
        optimizer_disc_B.step()
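
A minimal sketch of what seems to be happening from epoch 2 onwards, assuming out_A is still attached to the graph that produced it when it is fed into Gen_model_B. The two small nn.Sequential models stand in for Gen_model_A and Gen_model_B, and the squared-mean losses stand in for loss_fn / loss_fn_I; all of these names are hypothetical. loss_a.backward() frees gen_a's graph, and loss_b.backward() then tries to walk back into it through out_a unless out_a is detached first:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for Gen_model_A and Gen_model_B. The final Tanh saves
# its output, so each forward pass leaves buffers that backward() frees.
gen_a = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
gen_b = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
x = torch.randn(8, 4)  # hypothetical input batch


def second_backward_fails(detach_input: bool) -> bool:
    """Return True if loss_b.backward() raises the 'second time' RuntimeError."""
    gen_a.zero_grad()
    out_a = gen_a(x)
    loss_a = out_a.pow(2).mean()  # placeholder for loss_fn / loss_fn_I
    loss_a.backward()             # frees the buffers saved in gen_a's graph

    # Feeding out_a straight into gen_b links gen_b's graph to gen_a's freed graph.
    inp_b = out_a.detach() if detach_input else out_a
    gen_b.zero_grad()
    out_b = gen_b(inp_b)
    loss_b = out_b.pow(2).mean()
    try:
        loss_b.backward()         # without detach, this walks back into gen_a's graph
    except RuntimeError as err:
        return "second time" in str(err)
    return False
```

Detaching is only appropriate if one generator's loss should not update the other generator. If gradients are meant to flow through both generators, an alternative is to add the losses computed in a step and call backward() once on the sum, so each graph is traversed in a single backward pass.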

I think it may be that I am not passing the graph between epochs 1 and 2. I was wondering whether the problem would be solved by using hidden.detach_(), but I cannot find anything about it in the PyTorch documentation. Can anyone help, please?
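
The hidden.detach_() pattern is commonly seen with RNN hidden states, but Tensor.detach() and Tensor.detach_() work on any tensor. A small sketch of how they differ (the tensors x, y, y_detached and z are hypothetical):

```python
import torch

x = torch.randn(3, requires_grad=True)

# detach(): returns a NEW tensor that shares storage but has no autograd history;
# the original tensor y stays attached to the graph.
y = x * 2
y_detached = y.detach()

# detach_(): detaches the tensor IN PLACE, so z itself loses its history.
# (PyTorch does not allow detaching a view in place.)
z = x * 3
z.detach_()
```

detach_() changes the tensor for every later use, whereas detach() lets you pass a detached copy (for example out_B.detach()) while the original stays attached.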

Many thanks,

Chaslie