I get an error (`RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time`) which requires me to set `retain_graph=True`. However, if I do so I get an out-of-memory error. Is there some way to release the memory after `loss.backward()`, or some other way to work around the error besides `retain_graph=True`?
This is representative of what I am trying to do:
```python
latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device, dtype=torch.float16)

for j in range(num_iter):
    latents_new = return_unet(latents, text_embeddings)
    loss, model_input = return_model(latents_new)
    loss.backward()  # first backward pass; frees the graph unless retain_graph=True

    # gradient step on the pixel values
    img2 = model_input['pixel_values'] - lr_rate * model_input['pixel_values'].grad

    m = torch.nn.Upsample(scale_factor=3.5, mode='nearest')
    loss2 = (img2 - m(latents_new)).sum()
    loss2.backward(retain_graph=True)  # without retain_graph=True this raises the RuntimeError;
                                       # with it, memory usage grows until OOM

    # gradient step on the text embeddings
    text_embeddings = text_embeddings - lr_rate * text_embeddings.grad
```
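For reference, here is a minimal standalone sketch of the same double-backward situation, separate from my actual model code. It only shows the error mechanism: a second `backward()` through an already-freed graph raises the `RuntimeError`, while `retain_graph=True` keeps the graph (and its memory) alive between the two calls.

```python
import torch

# First backward frees the saved intermediate buffers of the graph.
x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
y.backward()
try:
    y.backward()  # second backward through the same graph -> RuntimeError
except RuntimeError as e:
    print("second backward failed:", e)

# retain_graph=True keeps the buffers, so a second backward works,
# but the graph's memory is not released between the calls.
x2 = torch.randn(3, requires_grad=True)
y2 = (x2 * 2).sum()
y2.backward(retain_graph=True)
y2.backward()  # gradients accumulate into x2.grad
print("accumulated grad:", x2.grad)
```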