Inject a customized gradient at the output

Here’s my problem:
I have a loss function composed of two terms. The first term I can compute exactly, but the second comes from a function that I would have to integrate. That integral is (as far as I can tell) intractable, but this isn’t really a problem, because I have a closed-form expression for its gradient with respect to the model output.
So I’d like to compute a partial loss made up of only the known term, and then, when backpropagation starts, add the gradient of that partial loss to the analytic gradient I already have and propagate the sum back through the model.
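To make the intent clearer, here is a toy sketch of the behaviour I am after, with a dummy linear model, a stand-in partial loss, and a made-up gradient in place of my real analytic one (not my actual code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)

Z = model(x)                       # model output
partial_loss = (Z ** 2).sum()      # stand-in for the tractable term
known_grad = torch.ones_like(Z)    # stand-in for my analytic gradient

# gradient of the tractable term with respect to the output
partial_grad = torch.autograd.grad(partial_loss, Z, retain_graph=True)[0]

# push the combined gradient back through the model parameters
Z.backward(gradient=partial_grad + known_grad)

print(model.weight.grad)           # parameters now hold the combined gradient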
Here’s an example of my actual code:


import torch


class FR_loss:

    def __init__(self, G, X, k):
        self.X = X          # model output: node embeddings, shape (n_nodes, dim)
        self.G = G          # networkx graph
        self.k = k          # scale constant used in both terms of the loss

        # pairwise distances, keyed by the edge with sorted endpoints
        self.dists = {ord_edge((u, v)): torch.norm(self.X[v] - self.X[u])
                      for u in range(len(self.G)) for v in range(u + 1, len(self.G))}
        self.disp = self.compute_disp()

    def compute_disp(self):
        # displacement vector D_v for every node
        disp = dict()
        for v in self.G:
            D_v = 0
            not_v = list(set(self.G.nodes()) - {v})
            for u in not_v:
                e_uv = (self.X[v] - self.X[u]) / self.dists[ord_edge((u, v))]
                D_v += -((self.k ** 2) / self.dists[ord_edge((u, v))]) * e_uv
            for u in self.G.neighbors(v):
                e_uv = (self.X[v] - self.X[u]) / self.dists[ord_edge((u, v))]
                D_v += ((self.dists[ord_edge((u, v))] ** 2) / self.k) * e_uv
            disp[v] = D_v
        return disp

    def compute_remaining_grad(self, t, lr):
        # analytic gradient of the intractable term, for the nodes whose
        # displacement exceeds the threshold t
        grad_E_v_larger = torch.zeros((len(self.G), self.X.shape[1]))

        for v in self.G:
            D_v = self.disp[v]
            if torch.norm(D_v) > t:
                grad_E_v_larger[v] += t * (D_v / torch.norm(D_v))

        return grad_E_v_larger * 1. / lr

    def compute_loss_partial(self, t, lr):
        # tractable part of the loss, for the nodes whose displacement
        # stays below the threshold t
        E = 0
        for v in self.G:
            E_v = 0
            not_v = list(set(self.G.nodes()) - {v})
            D_v = self.disp[v]
            if torch.norm(D_v) < t:
                for u in not_v:
                    E_v += -(self.k ** 2) * torch.log(self.dists[ord_edge((u, v))])
                for u in self.G.neighbors(v):
                    E_v += (1. / (3 * self.k)) * ((self.dists[ord_edge((u, v))]) ** 3)
            E += E_v

        return E * 1. / lr
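(For completeness, `ord_edge` is just a small helper that sorts the endpoints of an edge, so that both orientations hit the same key in `self.dists`; something like:)

def ord_edge(e):
    # return the edge with its endpoints sorted, so (u, v) and (v, u)
    # map to the same dictionary key
    u, v = e
    return (u, v) if u < v else (v, u)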

Then I run:

def train_emb_FR(model, data, k, t, optimizer):

    torch.set_grad_enabled(True)
    model.train()
    G = nx_from_torch(data.edge_index)   # networkx graph built from the edge index
    Z = model(data.x)                    # model output (node embeddings)
    Loss = FR_loss(G, Z, k)

    loss = Loss.compute_loss_partial(t, optimizer.defaults['lr'])
    remain_grad = Loss.compute_remaining_grad(t, optimizer.defaults['lr'])
    optimizer.zero_grad()

    # gradient of the partial (tractable) loss with respect to the output
    part_gradients = torch.autograd.grad(loss, Z, retain_graph=True)[0]

    # this is where I try to inject the combined gradient at the output
    if loss == 0:
        Z.grad = remain_grad
    else:
        Z.grad = part_gradients + remain_grad

    loss.backward()
    optimizer.step()
    return loss.item()
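(`nx_from_torch` just builds a networkx graph from the edge_index tensor, roughly:)

import networkx as nx

def nx_from_torch(edge_index):
    # build an undirected networkx graph from a (2, num_edges) edge_index tensor
    G = nx.Graph()
    G.add_edges_from(edge_index.t().tolist())
    return G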

But it does not work the way I want it to…