backward() spits out an error the second time through the loop

I am trying to run the code below and get stuck on the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
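
For what it's worth, the same message shows up in a tiny standalone example (just to illustrate the error, not taken from my code):

import torch as to

x = to.ones(1, requires_grad=True)
y = (x * x).sum()  # builds a graph from x to y
y.backward()       # first backward works and frees the graph's buffers
y.backward()       # raises the RuntimeError quoted above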

I suspect that there is some issue with F_loss, since the code works if, in the assignment of G, I replace F_loss with a freshly evaluated call to the loss function. Could someone explain this behaviour to me and help me make the code run correctly? It also runs if I just add the line

        F_loss = F_V_loss(X0, TA, dim, r1, sparse_loss)

before calling backward inside the loop. But I am still surprised, since I explicitly set requires_grad=True in the first call.
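
With that extra line the loop begins like this (rest of the body unchanged):

for t in range(1, num_steps+1):
    F_loss = F_V_loss(X0, TA, dim, r1, sparse_loss)  # loss re-evaluated every iteration
    F_loss.backward(retain_graph=True)
    G = to.autograd.grad(F_loss, X0)[0].data.clone()
    # ... rest of the loop body as in the full code

Here is the full code: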

from __future__ import print_function  # __future__ imports must come before all others

import numpy as np
import torch as to
import torch.nn as nn
from scipy.stats import unitary_group
from torch.autograd import Variable


def F_V_loss(v, u, dim, parms1=0.0, sparsity=False):
    # squared Frobenius distance, optionally plus a squared L1 penalty on v
    if sparsity:
        return to.pow(to.norm(u-v,2),2) + parms1*to.pow(to.norm(v,1),2)
    else:
        return to.pow(to.norm(u-v,2),2)

    
verbose = False
sparse_loss = True
reps = 1
dim = 100
r1 = 0.00 # weight of the sparsity (L1) penalty, used iff sparse_loss == True

num_steps = 10000
T1 = to.empty(dim,dim)
T2 = to.empty(dim,dim)
Id=to.eye(dim)
Id2 = to.eye(2*dim)
TA = to.nn.init.orthogonal_(T1) # see https://pytorch.org/docs/0.3.1/_modules/torch/nn/init.html
tau = 0.01

X0 = Variable(to.nn.init.orthogonal_(T2), requires_grad=True)  # the matrix being optimised
F_loss = F_V_loss(X0, TA, dim, r1, sparse_loss)
for t in range(1,num_steps+1):
    F_loss.backward(retain_graph=True)                # the error is raised here on the second iteration
    G = to.autograd.grad(F_loss, X0)[0].data.clone()  # gradient of the loss w.r.t. X0
    X = X0.data.clone()
    W = to.mm(G,to.transpose(X,0,1)) - to.mm(X, to.transpose(G,0,1))  # skew-symmetric
    # Y = to.mm((Id + 0.5*tau*W).inverse(), (Id - 0.5*tau*W))  # too expensive to compute!
    U = to.cat((G,X),1)   # W = U V^T, so the dim x dim inverse above can be
    V = to.cat((X,-G), 1) # replaced by a 2*dim x 2*dim one (see check below)
    YP = to.mm(U,to.mm((Id2 + 0.5*tau*to.mm(to.transpose(V,0,1),U)).inverse(), to.transpose(V,0,1)))
    X0.data = (X - 0.5*tau* to.mm(YP,X)).data.clone()
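
In case the linear algebra at the end of the loop looks odd: the U/V construction evaluates (Id + 0.5*tau*W).inverse() @ W without ever forming the dim x dim inverse, via the push-through/Woodbury identity, using only a 2*dim x 2*dim inverse. A quick standalone check that the two forms agree (small dim, all names local to this snippet):

import torch as to

dim, tau = 5, 0.01
G = to.randn(dim, dim)
X = to.randn(dim, dim)
Id = to.eye(dim)
Id2 = to.eye(2*dim)

W = G @ X.t() - X @ G.t()  # skew-symmetric, and W = U V^T with U, V below
U = to.cat((G, X), 1)
V = to.cat((X, -G), 1)

dense = (Id + 0.5*tau*W).inverse() @ W
lowrank = U @ (Id2 + 0.5*tau*(V.t() @ U)).inverse() @ V.t()
print(to.allclose(dense, lowrank, atol=1e-5))  # prints True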