I get an error (RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time), which requires me to set retain_graph=True. However, if I do so, I get an out-of-memory error. Is there some way to release the memory after loss.backward()? Or some other way to work around the error besides retain_graph=True?
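For context, here is a minimal standalone sketch (unrelated to my actual models) of the failure mode as I understand it: two losses built on one shared intermediate tensor, where the first backward() frees the saved tensors that the second backward() still needs:

import torch

x = torch.ones(3, requires_grad=True)
h = x.exp()          # exp() saves its output for the backward pass
loss_a = h.sum()
loss_b = (3 * h).sum()
loss_a.backward()    # frees the saved tensors of the graph shared through h
loss_b.backward()    # RuntimeError: Trying to backward through the graph a second time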
This is representative of what I am trying to do:
import torch

latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device, dtype=torch.float16)

upsample = torch.nn.Upsample(scale_factor=3.5, mode='nearest')

for j in range(num_iter):
    latents_new = return_unet(latents, text_embeddings)
    loss, model_input = return_model(latents_new)
    # without retain_graph=True here, loss2.backward() below raises the
    # RuntimeError; with it, the graph is never released and I eventually OOM
    loss.backward(retain_graph=True)
    # pixel_values must require grad for its .grad to be populated here
    img2 = model_input['pixel_values'] - lr_rate * model_input['pixel_values'].grad
    # loss2 also depends on latents_new, i.e. on the same graph as loss
    loss2 = (img2 - upsample(latents_new)).sum()
    loss2.backward()
    text_embeddings = text_embeddings - lr_rate * text_embeddings.grad