Torch SVD gradients are all zero when only Vt[-1] is used

I want to estimate a homography with SVD. Following the formula, I take the last row of Vt (the last right singular vector), but the gradients torch computes are all zero.
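For context, the formula referred to is presumably the standard direct linear transform (DLT). This assumes A stacks two rows per point correspondence, so four correspondences give an 8 x 9 matrix. The homography vector is the unit vector minimizing the residual, which is the last column of V (the last row of Vt, i.e. Vt[-1] in the code) and spans the null space of A when A has rank 8. Dividing by its last entry fixes both the scale and the sign:

```latex
\mathbf{h} = \arg\min_{\lVert \mathbf{h} \rVert_2 = 1} \lVert A\mathbf{h} \rVert_2,
\qquad A = U \Sigma V^\top \;\Rightarrow\; \mathbf{h} = \mathbf{v}_9 \ (\text{last column of } V),
\qquad
H = \frac{1}{h_9}
\begin{pmatrix}
h_1 & h_2 & h_3 \\
h_4 & h_5 & h_6 \\
h_7 & h_8 & h_9
\end{pmatrix}
```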
import torch

A_rows = [torch.randn(8, 9)]  # example data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
A = torch.cat(A_rows).view(-1, 9).to(device)
A.requires_grad_(True)

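U, S, Vt = torch.linalg.svd(A)  # full_matrices=True by default: U is 8x8, Vt is 9x9

H = Vt[-1] / Vt[-1, -1]  # last right singular vector, scaled so its last entry is 1
H = H.view(3, 3)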

A.retain_grad()
H.retain_grad()
Vt.retain_grad()

loss = H.norm()
print("Loss:", loss)

loss.backward()

print("A.grad:", A.grad)
print("V.grad:", Vt.grad)
print("H.grad:", H.grad)

Then I print the gradients of H, Vt, and A, which gives:
estimated_homo_matrix grad:
A.grad: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')
V.grad: tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
        [ 1.7672, -1.3057, 2.8489, 1.8927, 0.2041, 2.3516, 0.9954, -1.9848,
          27.0288]], device='cuda:0')
H.grad: tensor([[-0.3338, 0.2466, -0.5381],
        [-0.3575, -0.0386, -0.4442],
        [-0.1880, 0.3749, 0.1889]], device='cuda:0')
Why is A.grad all zeros?

Hi Tao!

The short story is that the singular-value decomposition is not necessarily unique,
so pytorch declines to attempt to compute certain gradients. (There is some nuance
here, so if you have specific questions, please ask.)
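To make the non-uniqueness concrete: flipping the sign of a matching pair of singular vectors (a column of U and the corresponding row of Vt) gives another, equally valid svd. This is a minimal sketch on a random float64 matrix, not taken from the original post. It is also why a loss built from singular vectors should be invariant to such sign flips, as H = Vt[-1] / Vt[-1, -1] is.

```python
import torch

torch.manual_seed(0)
A = torch.randn(8, 9, dtype=torch.float64)

# Reduced svd: U is 8x8, S has 8 entries, Vt is 8x9
U, S, Vt = torch.linalg.svd(A, full_matrices=False)

# Flip the sign of one left/right singular-vector pair
U2, Vt2 = U.clone(), Vt.clone()
U2[:, 0] *= -1
Vt2[0] *= -1

# Both factorizations reconstruct A equally well
recon1 = U @ torch.diag(S) @ Vt
recon2 = U2 @ torch.diag(S) @ Vt2
print(torch.allclose(recon1, A), torch.allclose(recon2, A))  # True True
```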

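In your case, A is 8 x 9, so it has fewer rows than columns. With the default full_matrices = True, Vt is 9 x 9, and its rows beyond min(m, n) = 8 (here just Vt[-1], the very row you use) form a basis for the null space of A. In general such a basis is not unique (again, some nuance), so pytorch ignores any gradient flowing into those rows and you get zero for A.grad.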
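A quick way to see which part of Vt is affected is to inspect the shapes of a full svd of an 8 x 9 matrix like the one in the question. The sketch below assumes random float64 data (not an actual homography system) and checks that Vt[-1] is, numerically, a null vector of A.

```python
import torch

torch.manual_seed(0)
A = torch.randn(8, 9, dtype=torch.float64)

# Default full_matrices=True: Vt is square (9x9) even though A has only 8 rows
U, S, Vt = torch.linalg.svd(A)
m, n = A.shape
k = min(m, n)
print(U.shape, S.shape, Vt.shape)  # torch.Size([8, 8]) torch.Size([8]) torch.Size([9, 9])

# Rows k..n-1 of Vt (here only row 8, i.e. Vt[-1]) span the null space of A;
# they have no matching singular value in S
null_rows = Vt[k:]
print(null_rows.shape)  # torch.Size([1, 9])
print(torch.allclose(A @ Vt[-1], torch.zeros(m, dtype=A.dtype), atol=1e-10))  # True
```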

Quoting from the documentation for torch.linalg.svd:

Note

When full_matrices=True, the gradients with respect to U[..., :, min(m, n):] and Vh[..., min(m, n):, :] will be ignored, as those vectors can be arbitrary bases of the corresponding subspaces.
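The effect of that note can be checked directly. The sketch below uses random float64 data and a hypothetical helper, grad_from_row, that builds the same sign-invariant loss as the question from one row of the full Vt of an 8 x 9 matrix. Using Vt[-1], whose index 8 is not less than min(m, n) = 8, should give an all-zero A.grad, while using Vt[7], the last row that is kept, gives a nonzero one.

```python
import torch


def grad_from_row(row):
    """Hypothetical helper: A.grad for a sign-invariant loss on one row of the full Vt."""
    torch.manual_seed(0)  # same random 8 x 9 matrix on every call
    A = torch.randn(8, 9, dtype=torch.float64, requires_grad=True)
    U, S, Vt = torch.linalg.svd(A)  # full_matrices=True by default, so Vt is 9 x 9
    v = Vt[row]
    # Dividing by the last entry (as in the homography code) makes the loss
    # independent of the arbitrary sign of v
    loss = (v / v[-1]).norm()
    loss.backward()
    return A.grad


g_null = grad_from_row(-1)  # row 8 >= min(m, n) = 8: its gradient is ignored
g_kept = grad_from_row(7)   # row 7 < min(m, n): its gradient is kept

print(torch.count_nonzero(g_null).item())      # 0
print(torch.count_nonzero(g_kept).item() > 0)  # True
```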

Here is a simplified version of your code that calls svd() with full_matrices = False:

import torch
print (torch.__version__)

_ = torch.manual_seed (2025)

A = torch.randn (2, 3, requires_grad = True)
print ('A = ...')
print (A)

U, S, Vt = torch.linalg.svd (A, full_matrices = False)   # compute "reduced" svd
print ('U.shape: ', U.shape, ', S.shape: ', S.shape, 'Vt.shape: ', Vt.shape)

print ('Vt = ...')
print (Vt)

H = Vt[-1] / Vt[-1, -1]

A.retain_grad()
H.retain_grad()
Vt.retain_grad()

loss = H.norm()
print('loss: ', loss)

loss.backward()

print('A.grad = ...')
print(A.grad)
print('Vt.grad = ...')
print(Vt.grad)
print('H.grad = ...')
print(H.grad)

And here is its output showing that in this case, gradients do flow back to A:

A = ...
tensor([[-0.8716,  0.1114,  1.2044],
        [-0.1803,  1.0021,  0.7914]], requires_grad=True)
U.shape:  torch.Size([2, 2]) , S.shape:  torch.Size([2]) Vt.shape:  torch.Size([2, 3])
Vt = ...
tensor([[-0.4441,  0.3981,  0.8027],
        [ 0.4845,  0.8603, -0.1586]], grad_fn=<LinalgSvdBackward0>)
loss:  tensor(6.3066, grad_fn=<LinalgVectorNormBackward0>)
A.grad = ...
tensor([[-24.4890,  -7.2269,  -1.7565],
        [ 13.0301, -22.3574,   7.9963]])
Vt.grad = ...
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 3.0556,  5.4255, 38.7727]])
H.grad = ...
tensor([-0.4845, -0.8603,  0.1586])
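One thing worth double-checking before using this for the homography itself: with full_matrices = False and an 8 x 9 A, Vt is 8 x 9, so Vt[-1] is the right singular vector of the smallest nonzero singular value, not the null-space vector the DLT formula asks for. (In the 2 x 3 example above, Vt[-1] likewise belongs to the nonzero singular value S[1].) One possible workaround, offered as a sketch and not taken from the reply above, is to pad A with a row of zeros so it becomes 9 x 9. The null vector of A is then the right singular vector of the zero singular value, at row index 8 < min(m, n) = 9, so its gradient is no longer discarded. This assumes the singular values of the padded matrix are distinct (generically true for 8 independent rows), since the svd backward is ill-defined for repeated singular values; the data here is random, not a real correspondence system.

```python
import torch

torch.manual_seed(0)
A = torch.randn(8, 9, dtype=torch.float64, requires_grad=True)

# Reduced svd: Vt[-1] is NOT the null vector of A here
_, S_red, Vt_red = torch.linalg.svd(A, full_matrices=False)
print(Vt_red.shape)                            # torch.Size([8, 9])
print((A @ Vt_red[-1]).norm().item() > 1e-6)   # True: the residual equals S_red[-1]

# Workaround sketch: pad A with a zero row so it is square (9 x 9)
A_sq = torch.cat([A, torch.zeros(1, 9, dtype=A.dtype)], dim=0)
U, S, Vt = torch.linalg.svd(A_sq)
h = Vt[-1]
print((A @ h).norm().item() < 1e-10)           # True: h is A's null vector

H = (h / h[-1]).view(3, 3)  # scale so H[2, 2] == 1 (this also removes the sign ambiguity)
loss = H.norm()
loss.backward()
print(torch.count_nonzero(A.grad).item() > 0)  # True: gradients reach A
```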

Best.

K. Frank