How can I handle this error "svd_cuda: the updating process of SBDSDC did not converge"

Hi all,
When I run the following code on CUDA, where A and B are 2D tensors of shape [128, 128]:
C = A.pinverse().matmul(B)
I get the error "svd_cuda: the updating process of SBDSDC did not converge".
How can I handle this error?
Any help would be appreciated.

Well, I would say it probably means your matrix is not invertible…?

You can check
https://www.planetmath.org/Pseudoinverse


Do you mean that it is not invertible because of its n-by-n dimensions? Or is it related to the entries (I mean the values) of the matrix?
What is the solution? How can I compute the inverse of the matrix A?
Actually, I need the matrix C to calculate eigenvalues and eigenvectors.

I mean, you should check that the matrix you want to invert satisfies the theoretical properties of an invertible matrix.

It looks like, in your case, your matrix doesn’t. Moreover, since A is square, you can compute its inverse directly with torch.linalg.inv():
https://pytorch.org/docs/stable/generated/torch.linalg.inv.html#torch.linalg.inv
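A minimal sketch of that check, assuming you only want to call torch.linalg.inv() when A is numerically full rank and not too badly conditioned. The helper inverse_if_well_conditioned and its max_cond threshold are hypothetical, chosen for illustration; float64 on the cpu is used so the example runs anywhere.

```python
import torch

def inverse_if_well_conditioned(A, max_cond=1e6):
    # Hypothetical helper; max_cond is an illustrative threshold, not a PyTorch default.
    n = A.shape[-1]
    rank = torch.linalg.matrix_rank(A).item()   # numerical rank from the singular values
    cond = torch.linalg.cond(A).item()          # largest / smallest singular value
    if rank < n or not (cond < max_cond):       # "not (<)" also rejects inf and nan
        return None, rank, cond
    return torch.linalg.inv(A), rank, cond

torch.manual_seed(0)
A = torch.randn(128, 128, dtype=torch.float64)
A_inv, rank, cond = inverse_if_well_conditioned(A)
print(rank, cond)

A_sing = A.clone()
A_sing[:, -1] = 0                   # a zero column makes the matrix singular
A_sing_inv, rank_sing, cond_sing = inverse_if_well_conditioned(A_sing)
print(rank_sing, A_sing_inv)        # the singular matrix gets no inverse
```

Checking the rank and condition number first turns a hard failure inside the inverse into a decision you control.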


Hi S!

First, to be clear, this is not about whether A is invertible. The whole
purpose of the pseudo-inverse (torch.pinverse()) is to perform a
mathematically useful operation even when the matrix is not invertible.
(And the pseudo-inverse coincides with the regular matrix inverse
(torch.inverse()) when the matrix is invertible.)
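To see why the pseudo-inverse is well defined even for a singular matrix, here is a minimal sketch that builds the Moore-Penrose pseudo-inverse from the SVD (the decomposition whose failure the error message reports). The helper pinv_via_svd is hypothetical and handles real 2-D matrices only; its rcond cutoff mirrors the one documented for torch.pinverse(), where singular values below rcond times the largest one are treated as zero.

```python
import torch

def pinv_via_svd(A, rcond=1e-15):
    # Hypothetical helper for real 2-D matrices: pseudo-inverse from A = U @ diag(S) @ Vh.
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    # Singular values at or below rcond * (largest singular value) are treated as zero,
    # so their reciprocals never enter the result.
    keep = S > rcond * S.max()
    S_inv = torch.where(keep, 1.0 / S, torch.zeros_like(S))
    # A^+ = V @ diag(1/S) @ U^T
    return Vh.T @ torch.diag(S_inv) @ U.T

torch.manual_seed(2022)
A = torch.randn(5, 5, dtype=torch.float64)
Az = A.clone()
Az[:, 4] = 0                         # singular: the last column is zero

# For an invertible matrix the pseudo-inverse coincides with the inverse.
print(torch.allclose(pinv_via_svd(A), torch.linalg.inv(A)))
# For the singular matrix no inverse exists, but A^+ still satisfies the Penrose conditions.
P = pinv_via_svd(Az, rcond=1e-10)
print(torch.allclose(Az @ P @ Az, Az), torch.allclose(P @ Az @ P, P))
```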

As to your specific problem, let me first note that I cannot reproduce
your exact error message (but I can cause other errors).

But things to try:

First check that A contains neither infs nor nans (A.isinf().any(),
A.isnan().any()). Either could lead to a pinverse() error.
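A minimal sketch of that check that also reports where the bad entries are, which helps trace them back to whatever produced them. The helper check_finite is a hypothetical name.

```python
import torch

def check_finite(name, t, max_report=5):
    # Hypothetical helper: report inf/nan entries, which can make the SVD inside
    # pinverse() fail. Returns True when every entry is finite.
    bad = ~torch.isfinite(t)            # True where t is inf, -inf, or nan
    if not bad.any():
        return True
    print(f"{name}: {t.isinf().sum().item()} inf, {t.isnan().sum().item()} nan")
    print("first positions:", bad.nonzero()[:max_report].tolist())
    return False

A = torch.randn(128, 128)
print(check_finite("A", A))
A[2, 3] = float('inf')
A[7, 0] = float('nan')
print(check_finite("A", A))
```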

Try upgrading to the latest stable pytorch version or the nightly build.
(By the way, which pytorch version are you using?)

Try using pinverse()'s rcond argument. You will have to experiment
with the value, e.g., A.pinverse (rcond = 1.e-2), and see whether
larger values eliminate your error.
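A minimal sketch of that experiment on a hypothetical, nearly singular 5x5 matrix whose singular values are known by construction; the rcond values are only examples. One caveat, as far as I understand the implementation: the cutoff is applied to the singular values after the SVD has been computed, so rcond mainly changes how tiny singular values are treated in the result and may not prevent a convergence failure inside the SVD itself. That is why each call is wrapped in try/except.

```python
import torch

torch.manual_seed(0)
# Nearly singular 5x5 matrix with known singular values (float64 for clarity).
Q1, _ = torch.linalg.qr(torch.randn(5, 5, dtype=torch.float64))
Q2, _ = torch.linalg.qr(torch.randn(5, 5, dtype=torch.float64))
s = torch.tensor([3.0, 2.0, 1.0, 0.5, 1e-8], dtype=torch.float64)
A = Q1 @ torch.diag(s) @ Q2.T

results = {}
for rcond in (1e-15, 1e-6, 1e-2):    # illustrative values to experiment with
    try:
        P = A.pinverse(rcond=rcond)
        results[rcond] = torch.linalg.norm(P).item()
        print(f"rcond={rcond:g}: ||pinv(A)||_F = {results[rcond]:.3g}")
    except RuntimeError as err:      # e.g. an svd_cuda convergence failure
        print(f"rcond={rcond:g}: failed ({err})")
```

With the default-sized cutoff, the tiny singular value 1e-8 is inverted and the pseudo-inverse blows up to a norm of about 1e8; once rcond times the largest singular value exceeds 1e-8, that direction is dropped and the result stays small.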

Replace your specific usage of pinverse() (namely multiplying B) with
C = torch.linalg.lstsq (A, B).solution.
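A minimal sketch comparing the two, with random stand-ins for A and B. One caveat from the torch.linalg.lstsq documentation: on CUDA the only available driver is 'gels', which assumes A has full rank. For a rank-deficient A, the sketch therefore uses the SVD-based 'gelsd' driver on the cpu, which returns the minimum-norm solution that pinverse(A) @ B gives. The rcond value of 1e-10 is illustrative.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
A = torch.randn(128, 128, device=device, dtype=torch.float64)
B = torch.randn(128, 128, device=device, dtype=torch.float64)

# Full-rank case: the least-squares solution of A @ C = B equals pinverse(A) @ B.
C_lstsq = torch.linalg.lstsq(A, B).solution
C_pinv = A.pinverse().matmul(B)
print(torch.allclose(C_lstsq, C_pinv, atol=1e-8))

# Rank-deficient case: 'gels' (the only CUDA driver) assumes full rank,
# so solve on the cpu with the SVD-based 'gelsd' driver instead.
Az = A.cpu().clone()
Az[:, -1] = 0
C_z = torch.linalg.lstsq(Az, B.cpu(), rcond=1e-10, driver='gelsd').solution
P_z = Az.pinverse(rcond=1e-10).matmul(B.cpu())
print(torch.allclose(C_z, P_z, atol=1e-8))
```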

Try perturbing A before calculating its pseudo-inverse:
A = A + 1.e-3 * torch.randn_like (A) (randn_like() creates the noise on
A's device). Again, play around with the scale factor.
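A minimal sketch of that idea as a retry loop: try the unperturbed matrix first, then add noise of growing size only if pinverse() raises. The helper pinverse_with_jitter and its scales are hypothetical. Any nonzero noise changes the answer slightly, so the scales should be chosen relative to the magnitude of A's entries.

```python
import torch

def pinverse_with_jitter(A, scales=(0.0, 1e-6, 1e-4, 1e-3)):
    # Hypothetical helper: retry pinverse() on A plus Gaussian noise of growing size.
    last_err = None
    for scale in scales:
        # randn_like() creates the noise on A's device with A's dtype; A itself is not modified.
        A_try = A + scale * torch.randn_like(A)
        try:
            return A_try.pinverse(), scale
        except RuntimeError as err:     # e.g. "svd_cuda: ... did not converge"
            last_err = err
    raise RuntimeError("pinverse() failed for every noise scale") from last_err

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
A = torch.randn(128, 128, device=device)
B = torch.randn(128, 128, device=device)
A_pinv, used_scale = pinverse_with_jitter(A)
C = A_pinv.matmul(B)
print("noise scale used:", used_scale)
```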

Perform the call to pinverse() on the cpu, e.g.,
C = A.cpu().pinverse().to (A.device).matmul (B). This costs two
device copies. If you need to backpropagate through this step, the
.cpu() and .to() copies are themselves differentiable, so autograd
should move the gradients between the gpu and the cpu for you.
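A minimal sketch of that cpu fallback, assuming a PyTorch version in which Tensor.cpu() and Tensor.to() are recorded by autograd, as they are in recent releases. The helper pinverse_on_cpu is a hypothetical name, and the squared-mean loss is only a stand-in for whatever loss actually uses C.

```python
import torch

def pinverse_on_cpu(A):
    # Run the SVD-based pseudo-inverse on the cpu, then move the result back to A's device.
    # Both copies are differentiable, so backward() routes gradients back to A's device.
    return A.cpu().pinverse().to(A.device)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
A = torch.randn(128, 128, device=device, requires_grad=True)
B = torch.randn(128, 128, device=device)

C = pinverse_on_cpu(A).matmul(B)     # C lives on the same device as A and B
loss = C.square().mean()             # stand-in loss
loss.backward()
print(A.grad.device, A.grad.shape)
```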

Here is an illustration of some of these points:

>>> import torch
>>> torch.__version__
'1.10.0'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> A = torch.randn (5, 5, device = 'cuda')
>>> Az = A.clone()
>>> Az[:, 4] = 0
>>> Ai = A.clone()
>>> Ai[2, 3] = float ('inf')
>>>
>>> A
tensor([[-0.4984,  1.0116, -0.2725, -2.4033, -0.0222],
        [ 1.0673,  0.9031, -0.7990,  2.6576, -1.7713],
        [ 0.2904,  0.1456, -0.0956, -2.0614, -0.5158],
        [-0.0729, -0.0587,  0.0698,  0.0745,  0.8687],
        [-2.3846,  2.1474, -0.1243,  0.6857, -0.0301]], device='cuda:0')
>>> A.inverse()
tensor([[-2.7031e+00,  6.0249e-01,  4.3399e+00,  3.7649e+00,  8.2871e-01],
        [-3.4336e+00,  6.6966e-01,  5.5499e+00,  4.6267e+00,  1.5518e+00],
        [-9.0598e+00,  5.3106e-01,  1.2691e+01,  8.5048e+00,  3.4168e+00],
        [-2.7618e-01,  9.6305e-02, -4.5012e-04,  1.9232e-01,  9.4890e-02],
        [ 2.9279e-01,  4.4850e-02, -2.8051e-01,  1.0797e+00, -1.0827e-01]],
       device='cuda:0')
>>> torch.allclose (A.pinverse(), A.inverse(), atol = 1.e-6)
True
>>>
>>> Az
tensor([[-0.4984,  1.0116, -0.2725, -2.4033,  0.0000],
        [ 1.0673,  0.9031, -0.7990,  2.6576,  0.0000],
        [ 0.2904,  0.1456, -0.0956, -2.0614,  0.0000],
        [-0.0729, -0.0587,  0.0698,  0.0745,  0.0000],
        [-2.3846,  2.1474, -0.1243,  0.6857,  0.0000]], device='cuda:0')
>>> Az.inverse()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: inverse_cuda: (Batch element 0): The diagonal element 5 is zero, the inversion could not be completed because the input matrix is singular.
>>> Az.pinverse()
tensor([[-3.1374,  0.5360,  4.7560,  2.1632,  0.9893],
        [-3.9337,  0.5931,  6.0290,  2.7825,  1.7368],
        [-9.6315,  0.4435, 13.2387,  6.3964,  3.6282],
        [-0.3025,  0.0923,  0.0248,  0.0951,  0.1046],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]], device='cuda:0')
>>>
>>> Ai
tensor([[-0.4984,  1.0116, -0.2725, -2.4033, -0.0222],
        [ 1.0673,  0.9031, -0.7990,  2.6576, -1.7713],
        [ 0.2904,  0.1456, -0.0956,     inf, -0.5158],
        [-0.0729, -0.0587,  0.0698,  0.0745,  0.8687],
        [-2.3846,  2.1474, -0.1243,  0.6857, -0.0301]], device='cuda:0')
>>> Ai.pinverse()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: svd_cuda: (Batch element 0): The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: 6).
>>>
>>> B = torch.randn (5, 5, device = 'cuda')
>>> torch.allclose (A.pinverse().matmul (B), torch.linalg.lstsq (A, B).solution, atol = 1.e-6)
True

Good luck.

K. Frank
