Passing a SparseTensor object to an optimizer

I'm playing around with pytorch-sparse on some toy examples to get a feel for it, since it doesn't have complete documentation.
Suppose I have a sparse matrix A and two vectors x and b such that Ax = b.
Given b, x, and the exact locations of the nonzeros (nnz) in A, I want to recover A through SGD. Here is the code I wrote for that, and it seems to work.

import numpy as np
import torch
from torch_sparse import SparseTensor

n = 10
lr = 0.01

# Random sparsity pattern with n entries (duplicate (row, col) pairs are possible)
row = torch.as_tensor(np.random.choice(n, n, replace=True), dtype=torch.long)
col = torch.as_tensor(np.random.choice(n, n, replace=True), dtype=torch.long)
data = torch.as_tensor(np.random.choice(np.arange(1, n + 1), n, replace=True), dtype=torch.float)

A_ground_truth = SparseTensor(row=row, col=col, value=data, sparse_sizes=(n, n))
# The matrix A that we want to learn; it shares the sparsity pattern of A_ground_truth
A_approx = SparseTensor(row=row, col=col, sparse_sizes=(n, n)).requires_grad_(True)

x = torch.randn(n, 1)  # the known vector x
b = A_ground_truth.matmul(x)  # the known vector b, so that Ax = b

for i in range(10000):
    output = A_approx.matmul(x)
    loss = torch.norm(b - output)
    loss.backward()
    row, col, value = A_approx.coo()
    print(f"Epoch: {i+1}, Loss: {loss.item():.3f}")
    new_value = value - lr * value.grad  # plain gradient-descent step on the nonzero values
    # Rebuild A_approx from the updated values (row and col are already long tensors here)
    A_approx = SparseTensor(row=row, col=col, value=new_value.detach(), sparse_sizes=(n, n)).requires_grad_(True)

However, I have some questions about it:

  1. As you can see, in each epoch I create A_approx again from the updated values. Is there a way to update the values of A_approx on the fly?
  2. What if I want to use an optimizer from torch.optim, such as torch.optim.Adam? How can I pass the values of A_approx to the optimizer? Should I pass only the values, or can I pass the whole SparseTensor object? Could you please provide a snippet for that?

You can check torch.optim.SparseAdam, which is designed for parameters whose gradients are sparse tensors.
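
For the dense-optimizer route asked about in question 2, here is a minimal sketch, not a definitive answer. It assumes only the tensor of nonzero values is handed to torch.optim.Adam while the row/col pattern stays fixed. To keep it runnable without torch_sparse, the helper sparse_matvec (a hypothetical name, not part of any library) stands in for A_approx.matmul(x). With torch_sparse, the analogous approach would presumably be to rebuild the SparseTensor from the same value tensor in every step.

```python
import torch

torch.manual_seed(0)
n = 10

# Fixed sparsity pattern: one (row, col) pair per nonzero entry.
row = torch.randint(0, n, (n,))
col = torch.randint(0, n, (n,))
true_value = torch.randint(1, n + 1, (n,)).float()


def sparse_matvec(row, col, value, x, num_rows):
    """Compute A @ x for a COO matrix; duplicate (row, col) entries are summed."""
    out = torch.zeros(num_rows, x.size(1), dtype=x.dtype)
    return out.index_add(0, row, value.unsqueeze(1) * x[col])


x = torch.randn(n, 1)
b = sparse_matvec(row, col, true_value, x, n)

# The nonzero values are the only learnable parameter, so the optimizer
# receives this dense 1-D tensor rather than a sparse matrix object.
value = torch.nn.Parameter(torch.ones(n))
optimizer = torch.optim.Adam([value], lr=0.01)

initial_loss = torch.norm(b - sparse_matvec(row, col, value, x, n)).item()
for step in range(2000):
    optimizer.zero_grad()
    loss = torch.norm(b - sparse_matvec(row, col, value, x, n))
    loss.backward()
    optimizer.step()  # updates `value` in place, nothing is re-created
final_loss = loss.item()
print(f"initial loss: {initial_loss:.3f}, final loss: {final_loss:.3f}")
```

Because optimizer.step() modifies value in place, this also covers question 1: the parameter tensor is created once and reused across iterations. Note that a single pair (x, b) generally does not determine A uniquely, so the learned values may satisfy Ax = b without matching the ground truth.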