I tried to use torch.linalg.lstsq as the last step of my CNN. With the default settings it usually outputs NaN, so I switched the driver to 'gelsy'. According to the documentation, 'gelsy' supports rank-deficient matrices but only runs on CPU.
My code is:
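For reference, this is a minimal sketch of the kind of call I mean, on a small made-up rank-deficient matrix rather than my real feature maps. The tensors and the `rcond=1e-10` tolerance are assumptions chosen only for illustration.

```python
import torch

# Made-up rank-deficient system: the second column of A duplicates the first,
# so A has rank 1 even though it has 2 columns.
A = torch.tensor([[1.0, 1.0],
                  [2.0, 2.0],
                  [3.0, 3.0]], dtype=torch.float64)
B = torch.tensor([[2.0], [4.0], [6.0]], dtype=torch.float64)

# 'gelsy' is a CPU-only driver, so both tensors are kept on the CPU here.
# rcond is an assumed tolerance for deciding the effective rank.
result = torch.linalg.lstsq(A, B, rcond=1e-10, driver='gelsy')

print(result.rank)      # effective rank detected by the solver
print(result.solution)  # a least-squares solution of A @ X = B
```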
def _lstsq(self, A, B):
    # A is the feature map extracted by the CNN and B is the reference image
    N, C_A, H, W = A.size()
    C_B = B.size(1)
    A = A.permute([0, 2, 3, 1]).reshape(N, H * W, C_A)
    ones = torch.ones_like(A[:, :, 0:1])
    A = torch.cat([ones, A], dim=2)
    B = B.permute([0, 2, 3, 1]).reshape(N, H * W, C_B)
    params = torch.linalg.lstsq(A, B, driver='gelsy').solution
    out = torch.nn.functional.relu(torch.bmm(A, params))
    out = out.reshape(N, H, W, C_B)
    out = out.permute([0, 3, 1, 2])
    return out
# ......
def lstsq_layer(self, A, B):
    device = A.device  # remember the original device before moving to CPU
    A = A.cpu()
    B = B.cpu()
    N, C_A, H, W = A.size()
    # normalize input A to [-1, 1] per channel
    A = A.reshape(N, C_A, -1)
    max_a = torch.max(A, dim=2, keepdim=True).values
    min_a = torch.min(A, dim=2, keepdim=True).values
    A = (2 * A - max_a - min_a) / (max_a - min_a)
    A = A.reshape(N, C_A, H, W)
    # calling torch.linalg.lstsq
    output = self._lstsq(A, B)
    output = output.to(device)  # .to() is not in-place, so reassign
    return output
This raised RuntimeError: false INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1646755903507/work/aten/src/ATen/native/LinearAlgebraUtils.h":328, please report a bug to PyTorch. torch.linalg.lstsq: Argument 4 has illegal value. Most certainly there is a bug in the implementation calling the backend library.
I then tried 'gelsd' and 'gelss', and I also updated MKL with conda, but the error persisted.
With 'gels' the code runs, but the output is always NaN.
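One thing I have not ruled out is non-finite values reaching the solver. The normalization divides by max_a - min_a, which is zero for any constant channel and would give 0/0 = NaN. This is only an assumption, not a confirmed cause of the error. Below is a minimal sketch of a guarded normalization plus a finite check, where `normalize_channels` and `eps` are hypothetical names and the input tensor is made up.

```python
import torch

def normalize_channels(A, eps=1e-8):
    """Scale each channel of A (N, C, H, W) to roughly [-1, 1].

    Hypothetical helper: eps is an assumed guard that keeps the
    denominator away from zero when a channel is constant.
    """
    N, C, H, W = A.shape
    flat = A.reshape(N, C, -1)
    max_a = flat.max(dim=2, keepdim=True).values
    min_a = flat.min(dim=2, keepdim=True).values
    span = (max_a - min_a).clamp(min=eps)
    flat = (2 * flat - max_a - min_a) / span
    return flat.reshape(N, C, H, W)

torch.manual_seed(0)
A = torch.rand(2, 3, 4, 4)
A[0, 1] = 0.5  # made-up constant channel: max == min

# Unguarded version, as in lstsq_layer: 0 / 0 on the constant channel.
flat = A.reshape(2, 3, -1)
max_a = flat.max(dim=2, keepdim=True).values
min_a = flat.min(dim=2, keepdim=True).values
naive = (2 * flat - max_a - min_a) / (max_a - min_a)

guarded = normalize_channels(A)

print(torch.isnan(naive).any())        # whether the unguarded result has NaN
print(torch.isfinite(guarded).all())   # whether the guarded result is finite
```

The same `torch.isfinite(...).all()` check could be run on A and B right before the torch.linalg.lstsq call to see whether the inputs are already broken at that point.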