NaN in U@torch.diag(S)@V.t() for huge data

I am trying to reconstruct my input using SVD, but I'm getting NaN values in U@torch.diag(S)@V.t().

Do you mind sharing a bit more about your data?

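Also, check whether you're using torch.svd or torch.linalg.svd, as they return different matrices: torch.svd returns V, whereas torch.linalg.svd returns V already (conjugate-)transposed, i.e. Vh.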
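To make the difference concrete, here's a minimal sketch on a small random float64 matrix (a stand-in for your data, not your actual images). With torch.svd you need V.t() to reconstruct, while torch.linalg.svd already hands you Vh; some=True and full_matrices=False both select the reduced SVD.

```python
import torch

# A small random matrix stands in for the real data (assumption for illustration).
torch.manual_seed(0)
A = torch.randn(6, 4, dtype=torch.float64)

# Legacy API: returns U, S, V, where V is NOT transposed.
U1, S1, V1 = torch.svd(A, some=True)
rec_old = U1 @ torch.diag(S1) @ V1.t()

# Current API: returns U, S, Vh, where V is already (conjugate-)transposed.
U2, S2, Vh2 = torch.linalg.svd(A, full_matrices=False)
rec_new = U2 @ torch.diag(S2) @ Vh2

# Both reconstruct A; the singular values agree, the singular vectors may differ in sign.
print(torch.allclose(rec_old, A), torch.allclose(rec_new, A))  # True True
```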

A set of 1000 ImageNet images of size 224x224, using torch.svd.

Would you mind sharing some code, so it's easier for me to understand where the bug appears?

Also, check whether there are any NaNs in your data (via torch.isnan(x).any()) before applying torch.svd. If you're normalising your data, check that the variance is non-zero, because dividing by a zero variance will produce infs or NaNs. And I assume you're decomposing all 1000 images in one go? If so, you need to use torch.diag_embed(S) rather than torch.diag(S)!
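Here's a rough sketch of the batched case, assuming the images are stored as a [B, H, W] tensor (the sizes below are made up and far smaller than yours). It first checks the input with torch.isfinite, which catches both NaN and inf, then builds one diagonal matrix per image with torch.diag_embed.

```python
import torch

# Hypothetical stand-in for a batch of images: B matrices of size H x W.
torch.manual_seed(0)
B, H, W = 8, 5, 5
images = torch.randn(B, H, W, dtype=torch.float64)

# Fail fast if the input already contains NaN or inf values.
assert torch.isfinite(images).all(), "input contains NaN or inf"

# Batched SVD: S has shape [B, min(H, W)].
U, S, Vh = torch.linalg.svd(images, full_matrices=False)

# On a 2-D input, torch.diag *extracts* a diagonal instead of building matrices.
# diag_embed builds one diagonal matrix per batch element: shape [B, k, k].
S_mat = torch.diag_embed(S)
reconstructed = U @ S_mat @ Vh

print(S_mat.shape)                            # torch.Size([8, 5, 5])
print(torch.allclose(reconstructed, images))  # True
```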

I normalise the images by subtracting the mean image (computed over all the images) from each image:

```python
import numpy as np
import torch
X = images.reshape(nImage, h * w)
mean_image = X.mean(axis=0)  # mean image over all images
Norm_image = X - mean_image
U, S, V = torch.svd(torch.from_numpy(Norm_image.astype(np.double)).cuda(), some=False)
```

I checked that the variance is non-zero.

@ptrblck, any solution for this?

Sorry, I didn't get a ping for this. Could you check whether any of the images are entirely zeros after you apply the mean-centering? Also, check that there are no infs or NaNs in your images before preconditioning.
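If it helps, something like the following could run both checks. find_bad_images is just a hypothetical helper, it assumes one flattened image per row as in your reshape, and the toy data is built so that image 2 equals the mean image and therefore becomes all zeros after centering.

```python
import torch

def find_bad_images(X: torch.Tensor):
    """For a [N, D] tensor (one flattened image per row), return the number of
    non-finite entries and the indices of rows that are entirely zero."""
    n_nonfinite = int((~torch.isfinite(X)).sum())
    zero_rows = (X == 0).all(dim=1).nonzero().flatten().tolist()
    return n_nonfinite, zero_rows

# Toy data (assumption): row 2 is exactly the mean of the three rows.
X = torch.tensor([[1.0, 2.0, 3.0],
                  [3.0, 6.0, 9.0],
                  [2.0, 4.0, 6.0]], dtype=torch.float64)

print(find_bad_images(X))           # before centering: (0, [])
X_centered = X - X.mean(dim=0)      # subtract the mean image
print(find_bad_images(X_centered))  # after centering: (0, [2])
```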

```python
images = torch.from_numpy(Norm_image.astype(np.double)).cuda()  # shape [1000, 224, 224]
```


Also, could you try the following code and see if the NaN issue goes away?

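```python
import torch
# Use torch.linalg.svd instead of torch.svd: it returns V already
# transposed (Vh), so no extra transpose is needed before reconstructing.
# full_matrices=False returns the reduced SVD, so the shapes also match
# for non-square images.
U, S, VT = torch.linalg.svd(images, full_matrices=False)
images_from_svd = U @ S.diag_embed() @ VT
```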

Is this reshape right? Perhaps you're introducing the NaN when mean-centering your images. If images has shape [B, W, H], you could just do image_mean = images.mean(dim=(-1, -2), keepdim=True) and then image_norm = images - image_mean (which is then passed to torch.linalg.svd(image_norm)). Note the keepdim=True: without it, image_mean has shape [B] and won't broadcast correctly against each image.
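A small sketch of the per-image centering, assuming a [B, W, H] float64 tensor. The toy sizes are chosen so that B != H, which makes a missing keepdim fail loudly instead of silently broadcasting along the wrong dimension.

```python
import torch

torch.manual_seed(0)
# Hypothetical batch: B images of size W x H (small sizes for illustration).
B, W, H = 4, 3, 5
images = torch.rand(B, W, H, dtype=torch.float64)

# keepdim=True keeps the mean at shape [B, 1, 1] so it broadcasts over each image.
image_mean = images.mean(dim=(-1, -2), keepdim=True)
image_norm = images - image_mean

# Without keepdim the mean would have shape [B], which lines up with the last
# dimension (H) rather than the batch dimension and errors out unless B == H.
print(image_mean.shape)  # torch.Size([4, 1, 1])

U, S, Vh = torch.linalg.svd(image_norm, full_matrices=False)
recon = U @ torch.diag_embed(S) @ Vh
print(torch.allclose(recon, image_norm))  # True
```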

Also, remember to wrap your code in 3 backticks ``` so it's displayed nicely! :slight_smile:

The reason I moved to torch.svd is the GPU memory limit: torch.linalg.svd could not fit the data in memory. Is it possible to run it in data-parallel mode?
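Two things that might help, offered as a sketch rather than a definitive fix. First, torch.linalg.svd computes the full SVD by default (full_matrices=True), and so does torch.svd with some=False as in your snippet. If you decompose the flattened data as a single [1000, 224*224] matrix, the full V is 50176 x 50176, roughly 20 GB in float64, whereas the reduced SVD keeps only a [1000, 50176] Vh (about 0.4 GB). Second, if you decompose each image separately as a [1000, 224, 224] batch, every SVD is independent, so you can split the batch into chunks and process them one at a time (or send different chunks to different GPUs yourself). This is manual chunking, not nn.DataParallel, and chunked_svd / chunk_size are hypothetical names:

```python
import torch

def chunked_svd(images: torch.Tensor, chunk_size: int, device: str = "cpu"):
    """Reduced SVD of a [B, m, n] batch, computed chunk by chunk along dim 0.

    Each matrix is decomposed independently, so splitting the batch does not
    change the result; only one chunk is on `device` at a time.
    """
    Us, Ss, Vhs = [], [], []
    for chunk in torch.split(images, chunk_size, dim=0):
        U, S, Vh = torch.linalg.svd(chunk.to(device), full_matrices=False)
        # Move results back to the CPU to free device memory for the next chunk.
        Us.append(U.cpu())
        Ss.append(S.cpu())
        Vhs.append(Vh.cpu())
    return torch.cat(Us), torch.cat(Ss), torch.cat(Vhs)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
images = torch.randn(10, 6, 6, dtype=torch.float64)  # hypothetical small batch
U, S, Vh = chunked_svd(images, chunk_size=4, device=device)
print(torch.allclose(U @ torch.diag_embed(S) @ Vh, images))  # True
```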

Did you check the other things I mentioned?

You can just use the images_from_svd line I posted above, but remember to transpose V, since torch.svd returns V rather than Vh.
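For a batched torch.svd result, a minimal sketch on toy data: V.t() only works on tensors with at most 2 dimensions, so for a batch you'd swap the last two dimensions with transpose(-2, -1). (torch.svd is also deprecated in favour of torch.linalg.svd, which avoids the transpose entirely.)

```python
import torch

torch.manual_seed(0)
images = torch.randn(8, 5, 5, dtype=torch.float64)  # hypothetical batch

# torch.svd returns V (not V^H); some=True by default gives the reduced SVD.
U, S, V = torch.svd(images)

# Swap the last two dims of V, which works for both single and batched inputs.
images_from_svd = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
print(torch.allclose(images_from_svd, images))  # True
```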