Hello everyone, I am implementing QR decomposition using modified Gram-Schmidt in a function and I am trying to get its gradient, but I am struggling. My function is:
import torch

def modified_gram_schmidt(A):
    m, n = A.shape
    Q = torch.zeros((m, n))
    R = torch.zeros((n, n))
    for j in range(n):
        v = A[:, j]
        for i in range(j):
            # subtract the projection of v onto the already-computed columns of Q
            R[i, j] = torch.dot(Q[:, i], v)
            v = v - R[i, j] * Q[:, i]
        R[j, j] = torch.norm(v)
        Q[:, j] = v / R[j, j]
    return Q, R
I am trying to compute its gradient using backward() as follows:
Q, R = modified_gram_schmidt(A)
Q.sum().backward()
dQ1_dA = A.grad
print(dQ1_dA)
It outputs None as a result. I think this is because the operations are not being tracked by Autograd. I created A with requires_grad=True. Any suggestion or idea on how to compute it correctly would be appreciated.
Does the following work for you?
import torch
def modified_gram_schmidt(A):
    m, n = A.shape
    # Set requires_grad so that gradients propagate through, clone to avoid
    # in-placing over a leaf tensor
    Q = torch.zeros((m, n), requires_grad=True).clone()
    R = torch.zeros((n, n), requires_grad=True).clone()
    for j in range(n):
        v = A[:, j]
        for i in range(j):
            R[i, j] = torch.dot(Q[:, i], v)
            v = v - R[i, j] * Q[:, i]
        R[j, j] = torch.norm(v)
        Q[:, j] = v / R[j, j]
    return Q, R
A = torch.rand(3, 5, requires_grad=True)
# This context manager is available in versions >=2.0.0
with torch.autograd.graph.allow_mutation_on_saved_tensors():
    Q, R = modified_gram_schmidt(A)
    Q.sum().backward()
dQ1_dA = A.grad
print(dQ1_dA)
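As an aside, if you would rather not rely on the context manager at all, a hypothetical variant is to avoid in-place writes entirely by collecting the columns of Q and R in plain Python lists and stacking them at the end. This is only a sketch of the same algorithm; the name modified_gram_schmidt_stacked and the tall 4x3 example are my own, not anything from your code:

import torch

def modified_gram_schmidt_stacked(A):
    # Same algorithm, but Q and R are assembled from lists at the end,
    # so no tensor saved for backward is ever mutated in place.
    m, n = A.shape
    Q_cols = []
    R_cols = []
    for j in range(n):
        v = A[:, j]
        r_j = []
        for i in range(j):
            r_ij = torch.dot(Q_cols[i], v)   # R[i, j]
            v = v - r_ij * Q_cols[i]
            r_j.append(r_ij)
        r_jj = torch.norm(v)                 # R[j, j]
        Q_cols.append(v / r_jj)
        r_j.append(r_jj)
        # zeros below the diagonal of column j
        r_j.extend([torch.zeros((), dtype=A.dtype, device=A.device)] * (n - j - 1))
        R_cols.append(torch.stack(r_j))
    Q = torch.stack(Q_cols, dim=1)   # (m, n)
    R = torch.stack(R_cols, dim=1)   # (n, n); column j is R_cols[j]
    return Q, R

A = torch.rand(4, 3, requires_grad=True)  # tall example so every column has a nonzero norm
Q, R = modified_gram_schmidt_stacked(A)
Q.sum().backward()
print(A.grad)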
Thank you so much! It worked, but the result doesn't coincide with the one I got from taking the gradient of torch.linalg.qr(A). I did it as follows:
Q1, R1 = torch.linalg.qr(A)  # torch version
loss1 = Q1.sum()
loss1.backward()
dQ1_dA = A.grad
print(dQ1_dA)
A.grad.zero_()
I also set the random seed to the same number and ensured the matrix A is the same for both functions, so probably my way of getting this last gradient has an issue.
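For reference, a cleaner way to compare the two gradients, so that the two backward() calls cannot accumulate into the same A.grad, might be something like the following; the seed and the 3x3 size are only illustrative, and modified_gram_schmidt is the version from your reply:

import torch

torch.manual_seed(0)   # illustrative seed
A0 = torch.rand(3, 3)

# two independent leaf copies of the same matrix
A_mgs = A0.clone().requires_grad_(True)
A_qr = A0.clone().requires_grad_(True)

with torch.autograd.graph.allow_mutation_on_saved_tensors():
    Q, _ = modified_gram_schmidt(A_mgs)
    Q.sum().backward()

Q1, _ = torch.linalg.qr(A_qr)
Q1.sum().backward()

print(A_mgs.grad)  # gradient through the hand-written version
print(A_qr.grad)   # gradient through torch.linalg.qr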
Does the forward pass produce the same result? Is it possible that the function is poorly conditioned on certain inputs? Have you tried a simple square matrix example?
Checking the gradients with gradcheck, it seems fine:
A = torch.tensor([
    [1., 2., 3.],
    [5., 3., 2.],
    [4., 1., 7.]
], requires_grad=True, dtype=torch.float64)

with torch.autograd.graph.allow_mutation_on_saved_tensors():
    torch.autograd.gradcheck(modified_gram_schmidt, A)
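gradcheck returns True when the analytical and numerical Jacobians agree within tolerance and raises an error otherwise, so you could also check the Q factor on its own, since your loss is Q.sum(); the lambda below is just an illustrative wrapper around the same function and matrix:

with torch.autograd.graph.allow_mutation_on_saved_tensors():
    ok = torch.autograd.gradcheck(lambda X: modified_gram_schmidt(X)[0], A)
print(ok)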
You are right, the forward pass of my function doesn't produce the exact same result for Q and R (it varies in some signs), but it passes the test when I compute Q @ R and compare it with the product obtained from the torch library. It is difficult to get the exact same result as torch since their method is probably different. And what do you mean by poorly conditioned? I also ran the last code you sent and there were no errors.
Oh, I mean ill-conditioned: the derivative at certain inputs may be steep, which can cause the gradient computed analytically to differ significantly from the one computed numerically via finite differencing.
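On the sign differences: a QR factorization is only unique up to the signs of the columns of Q (and the corresponding rows of R). Your modified Gram-Schmidt always produces a non-negative diagonal of R (since R[j, j] = torch.norm(v)), while torch.linalg.qr does not guarantee that, so one hypothetical sanity check is to flip the signs of torch's factors before comparing; the matrix below just reuses your square example:

import torch

A = torch.tensor([
    [1., 2., 3.],
    [5., 3., 2.],
    [4., 1., 7.]
], dtype=torch.float64, requires_grad=True)

Q1, R1 = torch.linalg.qr(A)

# Flip column j of Q1 and row j of R1 by the sign of R1[j, j], so that the
# diagonal of R1 becomes non-negative (the convention modified Gram-Schmidt uses).
signs = torch.sign(torch.diagonal(R1))
signs = torch.where(signs == 0, torch.ones_like(signs), signs)  # guard a zero diagonal
Q1_fixed = Q1 * signs                # scales column j by signs[j]
R1_fixed = signs.unsqueeze(1) * R1   # scales row j by signs[j]

# Q1_fixed @ R1_fixed still reconstructs A, and for a full-rank A the fixed factors
# should now agree with the ones from modified Gram-Schmidt up to rounding.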