So here’s a toy example from the paper *Automatic differentiation for Riemannian optimization on low-rank matrix and tensor-train manifolds*:
```python
import torch

def f(X):
    # simple smooth objective: squared Frobenius norm
    return torch.sum(X**2)

def g(delta_U, delta_V, U, V, f):
    # evaluate f through the tangent-space parametrization of X
    perturbed_matrix = U @ delta_V.t() + delta_U @ V.t()
    return f(perturbed_matrix)

def compute_riemannian_gradient(X):
    # torch.svd is deprecated; torch.linalg.svd returns V transposed
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    V = Vh.t()
    # at the point X itself: delta_U = U @ diag(S), delta_V = 0,
    # so that delta_U @ V.t() reproduces X exactly
    delta_U = U @ torch.diag(S)
    delta_V = torch.zeros_like(V)
    delta_U.requires_grad_(True)
    delta_V.requires_grad_(True)
    perturbed_value = g(delta_U, delta_V, U, V, f)
    perturbed_value.backward()
    return delta_U.grad, delta_V.grad

def apply_gauge_conditions(delta_U, delta_V, V):
    # remove the component of delta_V lying in span(V)
    delta_V = delta_V - V @ (V.t() @ delta_V)
    return delta_U, delta_V

def riemannian_gradient(X):
    U, _, Vh = torch.linalg.svd(X, full_matrices=False)
    V = Vh.t()
    delta_U, delta_V = compute_riemannian_gradient(X)
    delta_U, delta_V = apply_gauge_conditions(delta_U, delta_V, V)
    # map the factor gradients back to an ambient-space matrix
    return delta_U @ V.t() + U @ delta_V.t()

X = torch.randn(5, 3)
for i in range(10):
    rgrad = riemannian_gradient(X)
    X = X - 0.01 * rgrad
    print(f(X))
```
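As a quick sanity check on the construction: for f(X) = ||X||_F², the Euclidean gradient 2X already lies in the tangent space at X, so the Riemannian gradient should coincide with 2X. A condensed, standalone version of the routine above (my own compaction, not from the paper) can verify this numerically:

```python
import torch

def f(X):
    return torch.sum(X**2)

def riemannian_grad(X):
    # same construction as above, condensed into one function
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    V = Vh.t()
    delta_U = (U @ torch.diag(S)).requires_grad_(True)  # X = delta_U @ V.t()
    delta_V = torch.zeros_like(V).requires_grad_(True)
    val = f(U @ delta_V.t() + delta_U @ V.t())
    gU, gV = torch.autograd.grad(val, (delta_U, delta_V))
    gV = gV - V @ (V.t() @ gV)                          # gauge condition
    return gU @ V.t() + U @ gV.t()

X = torch.randn(5, 3)
# for this f the Riemannian gradient equals the Euclidean gradient 2X
print(torch.allclose(riemannian_grad(X), 2 * X, atol=1e-4))  # True
```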
As you can see, during training we never need the gradient with respect to X itself (or with respect to U, S, V). Instead, the gradients of delta_U and delta_V are what drive the update of X. That means I can't simply loop over the tensors registered via parameters() if I want to integrate this code into a torch.optim optimizer.
My question is: what is the proper way to implement this optimization algorithm inside optim.step(), given that the weight X is updated by gradients computed on other tensors (delta_U and delta_V) rather than by X.grad?
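To make the problem concrete, here is a hypothetical sketch of the shape I have in mind: a custom `torch.optim.Optimizer` subclass whose `step()` ignores `p.grad` entirely and recomputes the update from the SVD factors. The class name `RiemannianSGD` and the loss-function-in-the-constructor design are my own assumptions, not from the paper:

```python
import torch

def f(X):
    return torch.sum(X**2)

class RiemannianSGD(torch.optim.Optimizer):
    """Hypothetical sketch: step() ignores p.grad and instead recomputes
    the Riemannian gradient of a user-supplied loss from the SVD of p."""

    def __init__(self, params, lr=0.01, loss_fn=None):
        super().__init__(params, dict(lr=lr))
        self.loss_fn = loss_fn  # assumption: loss passed in, not read from p.grad

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for X in group["params"]:
                U, S, Vh = torch.linalg.svd(X, full_matrices=False)
                V = Vh.t()
                with torch.enable_grad():
                    delta_U = (U @ torch.diag(S)).requires_grad_(True)
                    delta_V = torch.zeros_like(V).requires_grad_(True)
                    val = self.loss_fn(U @ delta_V.t() + delta_U @ V.t())
                    gU, gV = torch.autograd.grad(val, (delta_U, delta_V))
                gV = gV - V @ (V.t() @ gV)      # gauge condition
                rgrad = gU @ V.t() + U @ gV.t()  # back to ambient space
                X.add_(rgrad, alpha=-group["lr"])

X = torch.randn(5, 3)
before = f(X).item()
opt = RiemannianSGD([X], lr=0.01, loss_fn=f)
for _ in range(10):
    opt.step()
print(f(X).item() < before)  # True: the objective decreases
```

The open question is whether holding the loss function inside the optimizer like this is idiomatic, or whether a closure passed to `step()` would be the preferred pattern.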