RuntimeError: cusolver error

Traceback (most recent call last):                                                                                                                                                                                                                                                                                        
  File "train.py", line 334, in <module>
    main(config, save_path)
  File "train.py", line 259, in main
    train_loss = train(train_loader, model, optimizer, epoch, w)
  File "train.py", line 170, in train
    w_loss_1 = w.compute()
  File "/home/nanana/mnt/nas/SR/lte/w2.py", line 114, in compute
    return w2_gaussian(self.mean_source, self.mean_target, self.cov_source, self.cov_target, eps=self.eps)
  File "/home/nanana/mnt/nas/w2.py", line 47, in w2_gaussian
    cov_source = make_psd(cov_source, strict=True, eps=eps)
  File "/home/nanana/mnt/nas/w2.py", line 32, in make_psd
    smallest_eig = matrices.min(-1)[0] if diag else min_eig(matrices)
  File "/home/nanana/mnt/nas/w2.py", line 27, in min_eig
    return torch.linalg.eigvalsh(matrices)[..., 0]
RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling `cusolverDnXsyevd( handle, params, jobz, uplo, n, CUDA_R_64F, reinterpret_cast<void*>(A), lda, CUDA_R_64F, reinterpret_cast<void*>(W), CUDA_R_64F, reinterpret_cast<void*>(bufferOnDevice), workspaceInBytesOnDevice, reinterpret_cast<void*>(bufferOnHost), workspaceInBytesOnHost, info)`

What’s the problem?

Could you post a minimal and executable code snippet reproducing the error, please?

import torch
import torch.nn as nn
import torch.nn.functional as F


STABILITY_CONST = 1e-8

def matrix_operator(matrices, operator):

    eigvals, eigvects = torch.linalg.eigh(matrices, UPLO='L')
    eigvals = torch.diag_embed(operator(eigvals))
    return eigvects @ eigvals @ eigvects.transpose(-2, -1)


def eye_like(matrices):
    
    return torch.eye(*matrices.shape[-2:-1], out=torch.empty_like(matrices)).expand_as(matrices)


def sqrtm(matrices):

    return matrix_operator(matrices, torch.sqrt)


def min_eig(matrices):

    return torch.linalg.eigvalsh(matrices)[..., 0]


def make_psd(matrices, strict = False, return_correction = False, diag = False, eps = STABILITY_CONST):

    smallest_eig = matrices.min(-1)[0] if diag else min_eig(matrices)
    small_positive_val = smallest_eig.clamp(max=0).abs()
    if strict: small_positive_val += eps
    if diag:
        res = matrices + small_positive_val[..., None]
    else:
        I = eye_like(matrices)
        res = matrices + I * small_positive_val[..., None, None]
    if return_correction:
        return res, small_positive_val
    return res


def w2_gaussian(mean_source, mean_target, cov_source, cov_target, eps = 1e-5):

    cov_source = make_psd(cov_source, strict=True, eps=eps)
    cov_target = make_psd(cov_target, strict=True, eps=eps)
    cov_target_sqrt = sqrtm(cov_target)
    mix = make_psd(cov_target_sqrt @ cov_source @ cov_target_sqrt, strict=False, eps=eps)

    mean_shift = torch.sum((mean_source - mean_target) ** 2, dim=-1)
    cov_shift_trace = torch.diagonal(cov_source + cov_target - 2 * sqrtm(mix), dim1=-2, dim2=-1).sum(dim=-1)
    return mean_shift + cov_shift_trace


class W2():

    def __init__(self, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.block_size = 3
    
    def update(self, images, distribution = "source"):
        
        samples = images.permute(0,2,1) # (B,C,HxW) -> (B,HxW,C)
        samples = samples.flatten(0,1).double() # (BxHxW,C)
        
        if distribution == "pred":
            self.source_features_sum = samples.sum(dim=-2)
            self.source_features_cov_sum = torch.einsum("...bi,...bj->...ij", samples, samples)
            self.source_features_num_samples = samples.size(-2)
        elif distribution == "gt":
            self.target_features_sum = samples.sum(dim=-2)
            self.target_features_cov_sum = torch.einsum("...bi,...bj->...ij", samples, samples)
            self.target_features_num_samples = samples.size(-2)
        else:
            raise NotImplementedError()

    @staticmethod
    def _mean_cov(sum, sum_corr, n):
        mean = sum / n
        cov = sum_corr - n * mean[..., None] @ mean[..., None, :]
        cov /= n - 1
        return mean, cov

    def compute(self):
        
        self.mean_source, self.cov_source = self._mean_cov(
            self.source_features_sum,
            self.source_features_cov_sum,
            self.source_features_num_samples
        )
        self.mean_target, self.cov_target = self._mean_cov(
            self.target_features_sum,
            self.target_features_cov_sum,
            self.target_features_num_samples
        )

        return w2_gaussian(self.mean_source, self.mean_target, self.cov_source, self.cov_target, eps=self.eps)

Thanks, I think the problem occurs in this.

Which method are we supposed to run to reproduce the issue?

gt_hpf_1 = torch.zeros(pred1.shape).cuda()
values_1 = gt_hpf[:,:,0].flatten().long()
z_coord_1, x_coord_1 = torch.where(torch.ones_like(gt_hpf[:,:,0]))
gt_hpf_1[z_coord_1, values_1, x_coord_1] = 1
  
w.update(gt_hpf_1, distribution="gt")
w.update(pred1, distribution="pred")
w_loss_1 = w.compute()
  
loss = w_loss_1

Thanks, This is it.