Autograd raises an error: NaNs encountered when trying to perform matrix-vector multiplication


(tritri) #1

Hi,

I'm using GPyTorch to implement a multi-output regression, but I get an error when I try to use a periodic kernel:

RuntimeError: NaNs encountered when trying to perform matrix-vector multiplication

It seems to work with the RBF kernel, but as soon as I switch to the periodic kernel, it no longer works.

The model class is defined as follows:

import gpytorch

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)

        # One constant mean per output (5 tasks)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=5)

        # Periodic kernel over the inputs, combined with a rank-1 task covariance
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.PeriodicKernel(), num_tasks=5, rank=1)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=5)
model = MultitaskGPModel(x_train, y_train, likelihood)
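To illustrate one way a periodic kernel can behave differently from an RBF kernel, here is a small stdlib-only sketch (no GPyTorch involved). It assumes the textbook exp-sine-squared form k(x, x') = exp(-2 sin²(π|x − x'|/p) / ℓ²); GPyTorch's own PeriodicKernel parameterization may differ between versions. Inputs spaced exactly one period apart get identical covariance rows, so the noise-free covariance matrix is singular and a Cholesky factorization fails, while the RBF matrix for the same inputs factorizes fine. The helper names (`periodic_kernel`, `rbf_kernel`, `gram`, `cholesky`) are hypothetical, and this is only a possible contributing factor, not a confirmed diagnosis of the error above.

```python
import math


def periodic_kernel(x1, x2, lengthscale=1.0, period=1.0):
    # Textbook exp-sine-squared kernel (assumed form, not GPyTorch's code).
    s = math.sin(math.pi * abs(x1 - x2) / period)
    return math.exp(-2.0 * s * s / lengthscale ** 2)


def rbf_kernel(x1, x2, lengthscale=1.0):
    # Squared-exponential (RBF) kernel, for comparison.
    return math.exp(-((x1 - x2) ** 2) / (2.0 * lengthscale ** 2))


def gram(xs, kernel, jitter=0.0):
    # Covariance matrix K[i][j] = k(x_i, x_j), with optional diagonal jitter.
    n = len(xs)
    return [[kernel(xs[i], xs[j]) + (jitter if i == j else 0.0)
             for j in range(n)] for i in range(n)]


def cholesky(a, tol=1e-12):
    # Lower-triangular L with a = L L^T, or None if some pivot is <= tol
    # (i.e. the matrix is not numerically positive definite).
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                pivot = a[i][i] - s
                if pivot <= tol:
                    return None
                L[i][i] = math.sqrt(pivot)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


xs = [0.0, 1.0, 2.0]  # inputs spaced exactly one period apart
print(cholesky(gram(xs, periodic_kernel)) is None)               # True
print(cholesky(gram(xs, periodic_kernel, jitter=1e-4)) is None)  # False
print(cholesky(gram(xs, rbf_kernel)) is None)                    # False
```

The `jitter` argument plays the role of a small diagonal noise term. The actual GPyTorch model adds the likelihood noise to the diagonal and solves with iterative methods rather than this Cholesky routine, so treat the sketch as an analogy only.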

And my training loop is:

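    # optimizer and mll are created earlier (not shown here)
    for i in range(n_iter):
        optimizer.zero_grad()          # clear gradients from the previous iteration
        output = model(x_train)        # model output at the training inputs
        loss = -mll(output, y_train)   # negative marginal log likelihood
        loss.backward()                # gradients w.r.t. the model's parameters
        print('Iter %d/%d - Loss: %.3f' % (i + 1, n_iter, loss.item()))
        optimizer.step()               # update the parameters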

The error occurs when I call optimizer.step(). I think it is related to the differentiation of the variables and to the requires_grad attribute, so I tried enabling gradients on the data like this:

x_train = torch.tensor(x_train, requires_grad=True).float()
y_train = torch.tensor(y_train, requires_grad=True).float()
x_test = torch.tensor(x_test, requires_grad=True).float()
y_test = torch.tensor(y_test, requires_grad=True).float()

But it didn't change anything.
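Setting `requires_grad=True` on the data does not change what is optimized: `optimizer.step()` only updates the parameters the optimizer was constructed with (in the GPyTorch examples, `model.parameters()`, i.e. the mean, kernel, and likelihood hyperparameters). A NaN that shows up around `optimizer.step()` therefore more often points at a NaN gradient in one of those hyperparameters. The stdlib-only sketch below illustrates one common source of such gradients: a kernel that depends on the unsquared distance |x − x'| (as a periodic kernel does, whereas the RBF kernel uses the squared distance), when that distance is computed as sqrt((x − x')²). At x == x', which holds on every diagonal entry of the covariance matrix, the chain rule produces inf · 0, which is NaN in IEEE 754 arithmetic. Whether the PeriodicKernel in the GPyTorch version used here computes the distance this way is an assumption, and all function names are hypothetical.

```python
import math


def d_sqrt(u):
    # Derivative of sqrt(u). Under IEEE 754 (as in PyTorch), 0.5 / sqrt(0) = inf;
    # plain Python would raise ZeroDivisionError, so that case is spelled out.
    return math.inf if u == 0.0 else 0.5 / math.sqrt(u)


def grad_squared_dist(x, y):
    # d/dx (x - y)^2 = 2 (x - y): the kind of term an RBF kernel differentiates.
    return 2.0 * (x - y)


def grad_dist_via_sqrt(x, y):
    # d/dx sqrt((x - y)^2) by the chain rule: sqrt'(u) * du/dx.
    u = (x - y) ** 2
    return d_sqrt(u) * 2.0 * (x - y)


def grad_dist_clamped(x, y, eps=1e-30):
    # Hypothetical workaround: sqrt(max(u, eps)). Below eps the clamp is flat,
    # so its derivative is 0 and the inf from sqrt is multiplied away to 0.
    u = (x - y) ** 2
    du_clamped = 1.0 if u > eps else 0.0
    return d_sqrt(max(u, eps)) * du_clamped * 2.0 * (x - y)


x = y = 0.5  # a diagonal entry of the covariance matrix: x == x'
print(grad_squared_dist(x, y))               # 0.0
print(math.isnan(grad_dist_via_sqrt(x, y)))  # True
print(grad_dist_clamped(x, y))               # 0.0
```

To locate where a NaN actually appears in the real model, wrapping the training step in PyTorch's `torch.autograd.detect_anomaly()` context manager makes the backward pass raise an error at the first operation that produced a NaN, together with the traceback of the corresponding forward call.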

Any help would be much appreciated. Thank you so much!

PS: PyTorch version 1.0