torch._C._LinAlgError: cusolver error

torch._C._LinAlgError: cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED, when calling cusolverDnSgetrf( handle, m, n, dA, ldda, static_cast<float*>(dataPtr.get()), ipiv, info). This error may appear if the input matrix contains NaN. But, I don’t know what reason for it. I can calculate the solution with NumPy.
First, I calculate the NTK using the function from here.

K_ss = empirical_ntk(fnet_single, params, images_syn_batch, images_syn_batch, 'trace')
K_ss_reg = (K_ss + 1e-6 * torch.trace(K_ss) * torch.eye(K_ss.shape[0], device=args.device) / K_ss.shape[0])

Then, I use the result to calculate the solution.

label_syn_batch = F.one_hot(label_syn[batch_syn_id], num_classes=num_classes).to(args.device).float()
solve = torch.linalg.solve(K_ss_reg, label_syn_batch)

the images_syn_batch is the random number by using torch.randn and label_syn is the label from 0 to 9

I write this code based on the code here written by JAX.

k_ss = kernel_fn(x_support, x_support)
k_ts = kernel_fn(x_target, x_support)
k_ss_reg = (k_ss + jnp.abs(reg) * jnp.trace(k_ss) * jnp.eye(k_ss.shape[0]) / k_ss.shape[0])
pred = jnp.dot(k_ts, sp.linalg.solve(k_ss_reg, y_support, sym_pos=True))

Could you post a minimal, executable code snippet which would reproduce the issue as well as the output of python -m torch.utils.collect_env, please?

The env is

Collecting environment information...
PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.27

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-81-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: A100-SXM4-40GB

Nvidia driver version: 450.51.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] functorch==0.1.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.3
[pip3] numpydoc==1.2
[pip3] torch==1.11.0
[pip3] torch-tb-profiler==0.1.0
[pip3] torchaudio==0.11.0
[pip3] torchlars==0.1.2
[pip3] torchvision==0.12.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] functorch                 0.1.0                    pypi_0    pypi
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] mypy_extensions           0.4.3            py38h06a4308_1  
[conda] numpy                     1.20.3           py38hf144106_0  
[conda] numpy-base                1.20.3           py38h74d4b33_0  
[conda] numpydoc                  1.2                pyhd3eb1b0_0  
[conda] pytorch                   1.11.0          py3.8_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     1.6.0+cu101              pypi_0    pypi
[conda] torch-tb-profiler         0.1.0                    pypi_0    pypi
[conda] torchaudio                0.11.0               py38_cu113    pytorch
[conda] torchlars                 0.1.2                    pypi_0    pypi
[conda] torchvision               0.7.0+cu101              pypi_0    pypi

The code is

import os
import copy
import torch
import random
import argparse
import warnings
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.utils import save_image
from torchvision import datasets, transforms
from functorch import make_functional, vmap, jacrev
warnings.filterwarnings("ignore")

CENTER = True
DATA_PATH = 'data/'

class AlexNet(nn.Module):
    def __init__(self, channel, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(channel, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.LocalResponseNorm(4, alpha=0.001 / 9.0, beta=0.75, k=1),
            nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(4, alpha=0.001 / 9.0, beta=0.75, k=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(4096, 384),
            nn.ReLU(inplace=True),
            nn.Linear(384, 192),
            nn.ReLU(inplace=True),
            nn.Linear(192, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 4096)
        x = self.classifier(x)
        return x

def empirical_ntk(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]
    
    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]
    
    # Compute J(x1) @ J(x2).T

    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        assert False
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]) 
    result = result.sum(0)
    return result

def set_random_seeds(seed):
    np.random.seed(seed)
    torch.manual_seed(seed) 
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 

def get_dataset(data_path, normalize=True):
    channel = 3
    im_size = (32, 32)
    num_classes = 10
    mean = [0.4914008, 0.482159  , 0.44653094]
    std = np.array([0.24703224, 0.24348514, 0.26158786])
    if normalize:
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    else:
        transform = transforms.Compose([transforms.ToTensor()])

    dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation
    dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

    return channel, im_size, num_classes, mean, std, dst_train, dst_test

def select_images(args, dst_train, num_classes, half_seed=None):
    ''' organize the real dataset '''
    
    images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
    images_all = torch.cat(images_all, dim=0)
    labels_all = torch.tensor([dst_train[i][1] for i in range(len(dst_train))], dtype=torch.long)
    images_all = images_all.to(args.device)

    indices_class = [[] for c in range(num_classes)]
    for i, lab in enumerate(labels_all):
        indices_class[lab].append(i)

    return images_all, labels_all, indices_class

# CUDA_VISIBLE_DEVICES=1
def main():
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--save_name', type=str, default='', help='additional surffix')
    parser.add_argument('--model', type=str, default='AlexNet', help='model')

    # DC parameters
    parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
    parser.add_argument('--Iteration', type=int, default=5000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=4e-2, help='learning rate for updating synthetic images')
    parser.add_argument('--batch_real', type=int, default=128, help='batch size for real data')
    parser.add_argument('--batch_syn', type=int, default=32, help='batch size for real data')
    parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--decay_step', type=int, default=1)

    parser.add_argument('--half_seed', type=int, default=None)

    args = parser.parse_args()
    
    eval_it_pool = np.arange(0, args.Iteration+1, args.Iteration // args.decay_step).tolist()[1:]
    print(eval_it_pool)
    args.clean = True
    channel, im_size, num_classes, mean, std, dst_train, dst_test = get_dataset(DATA_PATH)
    args.clean = False

    images_all, labels_all, indices_class = select_images(args, dst_train, num_classes, half_seed=args.half_seed)
            
    org_testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=2)

    for c in range(num_classes):
        print('class c = %d: %d real images'%(c, len(indices_class[c])))
    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

    ''' initialize the synthetic data '''
    image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, 
    requires_grad=True, device=args.device)
    label_syn = torch.arange(num_classes, device=args.device, dtype=torch.long).reshape(-1, 1).repeat(1, args.ipc).reshape(-1)
    
    if args.init == 'real':
        for c in range(num_classes):
            np.random.seed(42)
            idx_shuffle = np.random.permutation(indices_class[c])[:args.ipc]
            image_syn.data[c*args.ipc:(c+1)*args.ipc] = images_all[idx_shuffle].detach().data

    ''' training '''
    optimizer_img = torch.optim.Adam([image_syn,], lr=args.lr_img) # optimizer_img for synthetic data
    optimizer_img.zero_grad()
    # tqdm_range = trange(1, args.Iteration+1, desc='Loss', leave=True)

    criterion = torch.nn.MSELoss()

    for it in range(1, args.Iteration+1):
        ''' Evaluate synthetic data '''

        ''' Train synthetic data '''
        net = AlexNet().to(args.device)
        # net = ConvNetNTK(channel=channel, num_classes=num_classes, 
        #                             net_width=128, net_depth=3, net_act='relu', 
        #                             net_norm='none', net_pooling='avgpooling', im_size=im_size).to(args.device)
        net.train()
        # def reduce_logits(x):
        #     out = net(x)
        #     return torch.sum(out, dim=1) / (num_classes ** (1/2))

        ''' update synthetic data '''

        batch_real_id = np.random.choice(len(images_all), args.batch_real, replace=False)
        images_real_batch = images_all[batch_real_id]
        label_real_batch = F.one_hot(labels_all[batch_real_id], num_classes=num_classes).to(args.device).float()
        batch_syn_id = np.random.choice(len(image_syn), args.batch_syn, replace=False)
        images_syn_batch = image_syn[batch_syn_id]
        label_syn_batch = F.one_hot(label_syn[batch_syn_id], num_classes=num_classes).to(args.device).float()
        if CENTER:
            label_syn_batch -= 1/ num_classes
            label_real_batch -= 1/ num_classes

        fnet, params = make_functional(net)

        def fnet_single(params, x):
            return fnet(params, x.unsqueeze(0)).squeeze(0)

        K_ss = empirical_ntk(fnet_single, params, images_syn_batch, images_syn_batch, 'trace')
        K_ts = empirical_ntk(fnet_single, params, images_real_batch, images_syn_batch, 'trace')

        K_ss_reg = (K_ss + 1e-6 * torch.trace(K_ss) * torch.eye(K_ss.shape[0], device=args.device) / K_ss.shape[0])
        solve = torch.linalg.solve(K_ss_reg, label_syn_batch)
        # solve, _ = torch.solve(label_syn_batch, K_ss_reg)
        pred = torch.mm(K_ts, solve).to(args.device)
        loss = criterion(pred, label_real_batch)
        acc = torch.mean((torch.argmax(pred, dim=1) == torch.argmax(label_real_batch, dim=1)).float())

        optimizer_img.zero_grad()
        loss.backward()
        optimizer_img.step()