'RuntimeError: falseINTERNAL ASSERT FAILED' when using torch.linalg.lstsq

I tried to use torch.linalg.lstsq at the end of my CNN. With the default settings it usually outputs NaN, so I switched the driver to ‘gelsy’. According to the documentation, ‘gelsy’ supports rank-deficient matrices but only runs on CPU.
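
Because ‘gelsy’ is CPU-only while the rest of a network may live on the GPU (CUDA inputs only accept driver='gels', which assumes full rank), one option is to move the tensors to the CPU just for the solve. The following is a minimal sketch, not the poster's code: `lstsq_any_device` is a hypothetical helper and `rcond=1e-10` is an assumed rank tolerance. It uses a deliberately rank-deficient matrix (two identical columns) to show that ‘gelsy’ still returns a finite least-squares solution.

```python
import torch


def lstsq_any_device(A: torch.Tensor, B: torch.Tensor, rcond: float = 1e-10) -> torch.Tensor:
    # Hypothetical helper: CUDA only supports driver='gels' (full rank assumed),
    # so solve on CPU with the rank-revealing 'gelsy' driver, then move back.
    device = A.device
    result = torch.linalg.lstsq(A.cpu(), B.cpu(), rcond=rcond, driver="gelsy")
    return result.solution.to(device)


torch.manual_seed(0)
# Rank-deficient design matrix: a bias column plus two identical feature columns.
col = torch.randn(6, 1, dtype=torch.float64)
A = torch.cat([torch.ones(6, 1, dtype=torch.float64), col, col], dim=1)
B = torch.randn(6, 2, dtype=torch.float64)

X = lstsq_any_device(A, B)
# Any least-squares solution satisfies the normal equations A^T (A X - B) = 0.
normal_residual = A.T @ (A @ X - B)
print(torch.isfinite(X).all().item(), normal_residual.abs().max().item())
```

The round trip costs one host/device copy per call, which may or may not matter depending on batch size.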

My code is:

    import torch

    def _lstsq(self, A, B):
        # A is the feature map extracted by the CNN and B is the reference image
        N, C_A, H, W = A.size()
        C_B = B.size(1)

        A = A.permute([0,2,3,1]).reshape(N, H*W, C_A)
        ones = torch.ones_like(A[:, :, 0:1])
        A = torch.cat([ones, A], dim=2)
        B = B.permute([0,2,3,1]).reshape(N, H*W, C_B)

        params = torch.linalg.lstsq(A, B, driver='gelsy').solution

        out = torch.nn.functional.relu(torch.bmm(A, params))
        out = out.reshape(N, H, W, C_B)
        out = out.permute([0, 3, 1, 2])
        return out

    # ......

    def lstsq_layer(self, A, B):
        # 'gelsy' only runs on CPU, so move both inputs there
        A = A.cpu()
        B = B.cpu()
        N,C_A,H,W = A.size()

        # normalize input A to [-1,1]
        A = A.reshape(N, C_A, -1)
        max_a = torch.max(A, dim=2, keepdim=True).values
        min_a = torch.min(A, dim=2, keepdim=True).values
        A = (2*A-max_a-min_a)/(max_a-min_a)
        A = A.reshape(N, C_A, H, W)

        # calling torch.linalg.lstsq
        output = self._lstsq(A, B)
        return output

Running it raised:

    RuntimeError: falseINTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1646755903507/work/aten/src/ATen/native/LinearAlgebraUtils.h":328, please report a bug to PyTorch. torch.linalg.lstsq: Argument 4 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

I then tried ‘gelsd’ and ‘gelss’ and also updated MKL with conda, but the error persisted.

Switching to ‘gels’ lets the code run, but it always ends up outputting NaNs.
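
Before suspecting the LAPACK backend, it can help to confirm that the inputs are finite, since the error in this thread disappeared once the NaNs in A were gone. The sketch below is illustrative only: `check_finite` is a hypothetical helper (not a PyTorch API) and the tensor shapes are made up.

```python
import torch


def check_finite(name: str, t: torch.Tensor) -> None:
    # Hypothetical helper: fail early with a readable message if t holds NaN or Inf.
    bad = ~torch.isfinite(t)
    if bad.any():
        raise ValueError(f"{name} contains {int(bad.sum())} non-finite value(s)")


torch.manual_seed(0)
A = torch.randn(2, 16, 4, dtype=torch.float64)  # batch of tall design matrices
B = torch.randn(2, 16, 3, dtype=torch.float64)  # batch of targets
check_finite("A", A)
check_finite("B", B)
X = torch.linalg.lstsq(A, B, driver="gelsy").solution  # shape (2, 4, 3)

# A single NaN is enough to trip the check.
A_bad = A.clone()
A_bad[0, 0, 0] = float("nan")
try:
    check_finite("A_bad", A_bad)
    caught = False
except ValueError as err:
    caught = True
    print(err)
```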

Could you try installing the latest nightly binary and rerunning your code? If you still hit this error, could you create an issue on GitHub?

Problem solved. It was caused by the normalization A = (2*A-max_a-min_a)/(max_a-min_a). I forgot to add a small value to the denominator, so whenever a channel of A was constant (max_a equal to min_a) the expression became 0/0 and filled A with NaNs.
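
A minimal sketch of the fix described above, assuming the same N x C x H x W layout as `lstsq_layer`: `normalize_per_channel` is a hypothetical name and `eps=1e-8` is an assumed value. Passing `eps=0.0` reproduces the original behaviour, where a constant channel turns into NaN.

```python
import torch


def normalize_per_channel(A: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Min-max scale each (sample, channel) map to [-1, 1].
    # eps keeps the denominator positive when max == min (a constant channel).
    N, C, H, W = A.shape
    flat = A.reshape(N, C, -1)
    max_a = flat.amax(dim=2, keepdim=True)
    min_a = flat.amin(dim=2, keepdim=True)
    flat = (2 * flat - max_a - min_a) / (max_a - min_a + eps)
    return flat.reshape(N, C, H, W)


torch.manual_seed(0)
A = torch.rand(2, 3, 4, 4)
A[0, 1] = 5.0  # constant channel: numerator and denominator are both 0

unsafe = normalize_per_channel(A, eps=0.0)  # original formula: 0/0 -> NaN
safe = normalize_per_channel(A)             # constant channel maps to 0
print(torch.isnan(unsafe).any().item(), torch.isfinite(safe).all().item())
```

Clamping the denominator with `torch.clamp(max_a - min_a, min=eps)` is an alternative that leaves non-constant channels exactly in [-1, 1].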