What would be a proper way of implementing this Riemannian gradient?

So here’s a toy example from the paper "Automatic differentiation for Riemannian optimization on low-rank matrix and tensor-train manifolds":

import torch
import torch.nn as nn

def f(X):
    # Toy objective: squared Frobenius norm of X.
    return torch.sum(X**2)

def g(delta_U, delta_V, U, V, f):
    # Evaluate f on a first-order perturbation of X = U S V^T expressed in
    # tangent-space coordinates (delta_U, delta_V).
    perturbed_matrix = U @ delta_V.t() + delta_U @ V.t()
    return f(perturbed_matrix)

def compute_riemannian_gradient(X):
    U, S, V = torch.svd(X)
    # At delta_U = U S, delta_V = 0 the perturbed matrix equals X itself,
    # so differentiating g at this point gives the gradient of f in
    # tangent-space coordinates.
    delta_U = U @ torch.diag(S)
    delta_V = torch.zeros_like(V)
    delta_U.requires_grad_(True)
    delta_V.requires_grad_(True)
    perturbed_value = g(delta_U, delta_V, U, V, f)
    perturbed_value.backward()

    return delta_U.grad, delta_V.grad

def apply_gauge_conditions(delta_U, delta_V, V):
    # Project delta_V onto the orthogonal complement of range(V) to fix the
    # non-uniqueness of the (delta_U, delta_V) parametrization.
    delta_V -= V @ (V.t() @ delta_V)
    return delta_U, delta_V

def riemannian_gradient(X):
    U, _, V = torch.svd(X)
    delta_U, delta_V = compute_riemannian_gradient(X)
    delta_U, delta_V = apply_gauge_conditions(delta_U, delta_V, V)
    # Map the tangent-space coordinates back to a full-size matrix.
    return delta_U @ V.t() + U @ delta_V.t()

X = torch.randn(5, 3)

for i in range(10):
    rgrad = riemannian_gradient(X)
    X = X - 0.01*rgrad
    print(f(X))

So as you can see, in the training loop we never need the gradient of X itself, nor of [U, S, V]. Instead, the gradients of delta_U and delta_V are what drive the update of X. Because of that, I can't simply loop over the parameters registered with the optimizer and read each parameter's own .grad if I want to integrate this piece of code into the torch.optim module.
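To make the problem concrete, here is a small sketch reusing the functions above: backward() only ever populates delta_U.grad and delta_V.grad, so X.grad stays None and a plain torch.optim.SGD step on X does nothing:

X = torch.randn(5, 3, requires_grad=True)
opt = torch.optim.SGD([X], lr=0.01)

rgrad = riemannian_gradient(X.detach())  # the gradient lives in delta_U / delta_V
opt.step()                               # no-op: X.grad was never populated
print(X.grad)                            # None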

My question is: what is the proper way to implement this optimization algorithm in an optimizer's step() function, when the weight X is updated by gradients that live on other tensors (delta_U and delta_V) rather than on X itself?
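For context, the rough direction I have in mind is a custom Optimizer subclass whose step() ignores p.grad and recomputes the Riemannian gradient from the current value of the weight. This is only a sketch; the name RiemannianSGD and the grad_fn argument are placeholders of mine, not an existing PyTorch API:

class RiemannianSGD(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, grad_fn=riemannian_gradient):
        super().__init__(params, dict(lr=lr, grad_fn=grad_fn))

    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                # riemannian_gradient needs autograd internally (it calls
                # backward() on delta_U / delta_V), so keep grad enabled for
                # that part and disable it only for the in-place update.
                with torch.enable_grad():
                    rgrad = group["grad_fn"](p.detach())
                with torch.no_grad():
                    p.add_(rgrad, alpha=-group["lr"])

X = nn.Parameter(torch.randn(5, 3))
opt = RiemannianSGD([X], lr=0.01)
for _ in range(10):
    opt.step()          # no loss.backward() needed for this toy objective
    print(f(X).item())

The usage at the bottom shows what I would like the training loop to look like, but I'm not sure whether computing gradients inside step() like this is the intended way to use the optim API.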