Either "backward through graph a second time" error or out of memory error if i set retain_graph=True

I get an error (RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time) which requires me to set retain_graph=True. However, if I do so I get an out of memory error. Is there some way to release the memory after loss.backward(), or some other way to work around the error besides retain_graph=True?

This is representative of what I am trying to do:

    latents = torch.randn(
        (batch_size, 4, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device, dtype=torch.float16)

    for j in range(num_iter):
        latents_new = return_unet(latents, text_embeddings)
        loss, model_input = return_model(latents_new)
        loss.backward()

        # gradient step on the pixel values using the grads from loss.backward(),
        # then a second loss comparing the result against the upsampled latents
        img2 = model_input['pixel_values'] - lr_rate * model_input['pixel_values'].grad
        m = torch.nn.Upsample(scale_factor=3.5, mode='nearest')
        loss2 = (img2 - m(latents_new)).sum()
        loss2.backward(retain_graph=True)

        text_embeddings = text_embeddings - lr_rate * text_embeddings.grad

Based on the code it seems loss.backward() deletes the intermediates of return_unet and return_model, which will then cause the issue in loss2.backward(), since it uses model_input, which was created using both models.
You could sum the losses together and call .backward() on the sum, or you could call loss.backward(retain_graph=True) and afterwards loss2.backward(), but I would probably prefer the former approach.
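
Something like this toy example shows both options (just a stand-in graph, not your actual models):

    import torch

    x = torch.randn(4, requires_grad=True)
    y = x * 2                      # shared intermediate graph
    loss = y.sum()
    loss2 = (y ** 2).sum()

    # option 1: sum the losses and backpropagate once, so the graph
    # only needs to be traversed (and freed) a single time
    (loss + loss2).backward()

    # option 2: keep the graph alive for the first call and let the second free it
    # loss.backward(retain_graph=True)
    # loss2.backward()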

The reason I don't sum is that I need to get the gradients on model_input using loss.backward() first and then compute loss2 to get the gradients on text_embeddings. Strangely, I get an error if I use retain_graph=True on loss.backward() and not on loss2.backward(), but I don't get an error if I use retain_graph=True on loss2.backward(), as in the code snippet above.
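
To make the flow I am after concrete, here is a toy version with stand-in computations for return_unet / return_model (made-up shapes and lr_rate), with retain_grad() so the non-leaf .grad gets populated and the graph kept alive on the first backward only, as suggested above:

    import torch

    text_embeddings = torch.randn(2, 8, requires_grad=True)
    latents = torch.randn(2, 8)
    lr_rate = 0.1

    latents_new = latents * text_embeddings  # stand-in for return_unet
    pixel_values = latents_new * 3.0         # stand-in for model_input['pixel_values']
    pixel_values.retain_grad()               # .grad is only populated on non-leaf tensors after retain_grad()
    loss = pixel_values.sum()

    # keep the graph alive because loss2 reuses the part that produced latents_new
    loss.backward(retain_graph=True)

    # gradient step on the pixel values; detach so the update does not extend the old graph
    img2 = pixel_values.detach() - lr_rate * pixel_values.grad

    loss2 = (img2 - latents_new).sum()
    loss2.backward()                         # the graph is freed after this call

    with torch.no_grad():
        text_embeddings -= lr_rate * text_embeddings.grad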