Custom Top-eigenvector Function

I am trying to write a custom function that computes the dominant eigenvector of a symmetric matrix, and its derivative, using NumPy and Eq. (68) of the Matrix Cookbook.
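For reference, here is a sketch of how Eq. (68) becomes the backward formula. It assumes Eq. (68) is the eigenvector differential on the first line below, with v the unit top eigenvector, λ its eigenvalue, P the pseudo-inverse, and g = ∂L/∂v the incoming gradient (grad_output) of a scalar loss L:

$$dv = P\,(dA)\,v, \qquad P = (\lambda I - A)^{+}$$
$$dL = g^{\top} dv = g^{\top} P\,(dA)\,v = \operatorname{tr}\!\left(v\, g^{\top} P\, dA\right)$$
$$\frac{\partial L}{\partial A} = \left(v\, g^{\top} P\right)^{\top} = P^{\top} g\, v^{\top} = P\, g\, v^{\top}$$

The last step uses the fact that P is symmetric whenever A is. The middle form, P^T g v^T, is what np.dot(np.dot(output.T, pinv).T, vecs[None, :, -1]) computes in backward.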

Here is my code:

import torch
from torch.autograd import Function, Variable, gradcheck
import numpy as np


def to_tensor(x):
    return torch.from_numpy(x).float()

class TopEigen(Function):

    @staticmethod
    def forward(ctx, matrix):
        cov = matrix.numpy()
        lams, vecs = np.linalg.eigh(cov)
        ctx.intrm = (lams, vecs, cov)
        return to_tensor(vecs[:, -1, None])

    @staticmethod
    def backward(ctx, grad_output):
        lams, vecs, cov = ctx.intrm
        output = grad_output.data.numpy()
        pinv = np.linalg.pinv(lams[-1]*np.eye(*cov.shape) - cov)
        grad_np = np.dot(np.dot(output.T, pinv).T, vecs[None, :, -1])
        return Variable(to_tensor(grad_np))


# Testing
topeig = TopEigen.apply

p, q = 5, 3
in_tensor = Variable(torch.rand(p, q), requires_grad=True)
cov_in = torch.mm(in_tensor.t(), in_tensor)

out = topeig(cov_in).mean()
out.backward(retain_graph=True)

test = gradcheck(topeig, (cov_in, ), eps=1e-6, atol=1e-4)
print(test)

When I run this script, I get the following error message:

RuntimeError: for output no. 0,
 numerical:(
-0.0680  0.0404  0.0532
 0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000
-0.0015 -0.1016  0.0791
 0.0249 -0.0779  0.0285
 0.0000  0.0000  0.0000
-0.0019  0.0790 -0.0578
 0.0655 -0.0740 -0.0246
 0.0431  0.0375 -0.0817
[torch.FloatTensor of size 9x3]
,)
analytical:(
    0     0     0
    0     0     0
    0     0     0
    0     0     0
    0     0     0
    0     0     0
    0     0     0
    0     0     0
    0     0     0
[torch.FloatTensor of size 9x3]
,)

Why is the analytical output all 0? I checked the code with double() tensors, but the result was the same.

Can anyone help me with spotting the error in my code?

Hi,

I think you have the same problem as this post.
Basically, call cov_in.retain_grad() before passing it to gradcheck.


Thank you. It resolved the issue.

It seems the tutorial should be updated, right? I followed its steps exactly.

I just opened a PR against the master branch so that this issue no longer happens. :)


@albanD, @tom: I am still not able to get this function right. Here is my new code:

import torch
from torch.autograd import Function, gradcheck
from torch import from_numpy
import numpy as np


class TopEigen(Function):
    """Compute the top eigenvector of a matrix."""

    @staticmethod
    def forward(ctx, matrix):
        cov = matrix.detach().numpy()
        lams, vecs = np.linalg.eigh(cov)
        ctx.intrm = lams, vecs, cov  # backward() also unpacks cov
        return from_numpy(vecs[:, -1, None])

    @staticmethod
    def backward(ctx, grad_output):
        lams, vecs, cov = ctx.intrm
        output = grad_output.data.numpy()
        pinv = np.linalg.pinv(lams[-1]*np.eye(cov.shape[0]) - cov)
        return from_numpy(np.dot(np.dot(output.T, pinv).T, vecs[None, :, -1]))


# Create a random symmetric matrix
p, q = 5, 3
torch.manual_seed(0)
in_tensor = torch.rand(p, q, dtype=torch.float64, requires_grad=True)
cov_in = torch.mm(in_tensor.t(), in_tensor)

# Testing
topeig = TopEigen.apply
out = topeig(cov_in).mean()
test = gradcheck(topeig, (cov_in, ), eps=1e-10, atol=1e-3)
print(test)

Here is the output in PyTorch 0.4.1:

Traceback (most recent call last):
  File "eigen.py", line 41, in <module>
    test = gradcheck(topeig, (cov_in, ), eps=1e-10, atol=1e-3)
  File ".../anaconda3/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 214, in gradcheck
    'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
  File ".../anaconda3/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 194, in fail_test
    raise RuntimeError(msg)
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-0.0822,  0.0487,  0.0460],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [-0.0256, -0.0462,  0.0814],
        [ 0.0441, -0.0817,  0.0360],
        [ 0.0000,  0.0000,  0.0000],
        [-0.0219,  0.0801, -0.0611],
        [ 0.0819, -0.0387, -0.0568],
        [ 0.0381,  0.0329, -0.0820]], dtype=torch.float64)
analytical:tensor([[-0.0822,  0.0487,  0.0460],
        [-0.0743,  0.0441,  0.0416],
        [-0.0680,  0.0403,  0.0381],
        [ 0.0487, -0.0903,  0.0398],
        [ 0.0441, -0.0817,  0.0360],
        [ 0.0403, -0.0747,  0.0329],
        [ 0.0460,  0.0398, -0.0992],
        [ 0.0416,  0.0360, -0.0897],
        [ 0.0381,  0.0329, -0.0820]], dtype=torch.float64)

Interestingly, some rows match exactly. Could you help me spot my mistake?

Something’s funny in your formula. You could either use PyTorch master, which includes a derivative for symeig, or compare your calculation with the one implemented in PyTorch:

Best regards

Thomas

@tom: I understand that this functionality is already available in PyTorch, but this code is only a simplified version of what I want to implement, posted to track down the bugs.

My main goal is to compute the top eigenvector, and its derivative, for each matrix in a batch.

Could you take a quick look and point out my mistake?

Hi @tom. I found the issue. Eigendecomposition requires a symmetric input matrix, but gradcheck() perturbs one entry at a time, which makes the input asymmetric, so the eigendecomposition no longer reflects the perturbed matrix. I wrote my own numerical gradient-checking function, and the function above passed it.
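The mismatch pattern in the output above fits this. np.linalg.eigh reads only the lower triangle of its input by default (UPLO='L'). Perturbing an upper-triangle entry therefore changes nothing, which gives the all-zero numerical rows. Perturbing a lower-triangle entry behaves like a symmetric perturbation, so each of those numerical rows equals the sum of the analytical (i, j) and (j, i) rows. Below is a minimal sketch of a symmetric finite-difference check. top_eig_grad and symmetric_numerical_grad are hypothetical helpers written for illustration, not the author's own checker. The analytical side uses the same P g v^T formula as backward, with P built from the eigenpairs instead of np.linalg.pinv:

```python
import numpy as np

def top_eig_grad(a, g):
    """Analytical dL/dA = P g v^T for L = g^T v, v the top eigenvector of symmetric a."""
    lams, vecs = np.linalg.eigh(a)          # eigenvalues in ascending order
    v = vecs[:, -1:]                        # top eigenvector, shape (n, 1)
    gaps = lams[-1] - lams[:-1]             # lambda_top - lambda_i for the other eigenpairs
    # P = pinv(lambda_top * I - A) assembled from the eigenpairs (top term dropped)
    p = (vecs[:, :-1] / gaps) @ vecs[:, :-1].T
    return p @ g @ v.T

def symmetric_numerical_grad(a, g, eps=1e-6):
    """Central differences of L(A) = g^T v(A), moving (i, j) and (j, i) together."""
    n = a.shape[0]
    ref = np.linalg.eigh(a)[1][:, -1:]

    def loss(m):
        v = np.linalg.eigh(m)[1][:, -1:]
        v = v * np.sign((ref.T @ v).item())  # undo an arbitrary sign flip of v
        return (g.T @ v).item()

    num = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):
            e = np.zeros((n, n))
            e[i, j] = e[j, i] = eps         # symmetric perturbation keeps A symmetric
            num[i, j] = num[j, i] = (loss(a + e) - loss(a - e)) / (2 * eps)
    return num

rng = np.random.default_rng(0)
x = rng.random((5, 3))
a = x.T @ x
a = (a + a.T) / 2                           # make A exactly symmetric
g = rng.standard_normal((3, 1))             # stand-in for grad_output

ana = top_eig_grad(a, g)
# Along E_ij + E_ji the directional derivative is G_ij + G_ji (just G_ii on the diagonal)
expected = ana + ana.T - np.diag(np.diag(ana))
print(np.allclose(symmetric_numerical_grad(a, g), expected, atol=1e-6))

# eigh ignores the upper triangle, hence gradcheck's all-zero rows for those entries
e = np.zeros((3, 3))
e[0, 2] = 1e-3
print(np.allclose(np.linalg.eigh(a + e)[0], np.linalg.eigh(a)[0]))
```

Perturbing (i, j) and (j, i) together keeps every probe on the set of symmetric matrices, which is the only set on which the eigenvector function is meant to be differentiated here.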


Ha cool. Glad you found it.

Cool, could you please show the final full code?

Hi, why do you compute the gradients this way?
@tom @taha, should it be the Kronecker product of pinv with vecs[None, :, -1]?

Here I was only interested in the top eigenvector, so I used Eq. (68) of the Matrix Cookbook directly.
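On the Kronecker-product question: with column-major vec, vec(P dA v) = (v^T ⊗ P) vec(dA), so v^T ⊗ P is the full n × n² Jacobian of v with respect to A. Backward only needs the vector-Jacobian product (v^T ⊗ P)^T g = vec(P^T g v^T), which is exactly the n × n matrix the code returns, so the Kronecker product never has to be formed. Here is a small NumPy check of that identity. It uses random stand-ins for P, v and g rather than an actual eigendecomposition, because the identity holds for any matrices:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
p = rng.standard_normal((n, n))   # stand-in for pinv(lambda * I - A)
v = rng.standard_normal((n, 1))   # stand-in for the top eigenvector
g = rng.standard_normal((n, 1))   # stand-in for the upstream gradient dL/dv

# Full Jacobian of vec(P dA v) w.r.t. vec(dA) (column-major vec), shape (n, n*n)
jac = np.kron(v.T, p)

# Vector-Jacobian product through the explicit Jacobian, un-vec'd column-major
vjp_kron = (jac.T @ g).reshape(n, n, order="F")

# Same product without building the Jacobian, as backward() does: P^T g v^T
vjp_direct = p.T @ g @ v.T

print(np.allclose(vjp_kron, vjp_direct))
```

The explicit Jacobian costs O(n³) memory, while the direct product is a single outer product. That is why autograd's backward only asks for the vector-Jacobian product.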

Here is my final code. However, I strongly recommend NOT using it; it is very slow. Please use the batched eigendecomposition operations available since PyTorch 1.2 instead.

import torch
from torch.autograd import Function, gradcheck
from torch import from_numpy
import numpy as np

class TopEigen(Function):
    """Compute the top eigenvector of each matrix in a batch."""

    @staticmethod
    def _bdiag(matrix):
        """Turn a (b, d) array into a (b, d, d) batch of diagonal matrices."""
        b, d = matrix.shape
        dlams = np.zeros((b, d, d), dtype=matrix.dtype)
        idx = np.arange(d)
        dlams[:, idx, idx] = matrix
        return dlams

    @staticmethod
    def forward(ctx, tensor):
        cov = tensor.detach().numpy()
        # Batched eigh; eigenvalues are ascending, so column -1 is the top one.
        lams, vecs = np.linalg.eigh(cov)
        ctx.intrm = lams, vecs
        return from_numpy(vecs[:, :, -1, None])

    @staticmethod
    def backward(ctx, grad_output):
        lams, vecs = ctx.intrm
        output = grad_output.data.numpy()
        # Gaps lambda_top - lambda_i; invert those larger than 1e-6 (the top gap
        # is exactly 0 and stays 0) to get the eigenvalues of pinv(lambda_top*I - A).
        lams = lams[:, -1, None] - lams
        index = np.where(abs(lams) > 1e-6)
        lams[index] = 1/lams[index]
        dlams = TopEigen._bdiag(lams)
        # pinv = V diag(1 / gaps) V^T for every matrix in the batch.
        pinv = np.matmul(vecs, np.matmul(dlams, vecs.transpose((0, 2, 1))))
        # dL/dA = pinv^T g v^T, batched over the first dimension.
        return from_numpy(np.matmul(np.matmul(output.transpose((0, 2, 1)),
                          pinv).transpose((0, 2, 1)), vecs[:, None, :, -1]))
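For the batched case, here is a minimal sketch that uses the built-in op instead. It assumes PyTorch 1.8 or later, where torch.linalg.eigh is available; torch.symeig has since been deprecated in its favour. top_eigenvector is a hypothetical helper name. torch.linalg.eigh returns eigenvalues in ascending order, so the last column is the top eigenvector. The example loss is chosen so that it does not change when an eigenvector flips sign, because the sign eigh returns is arbitrary:

```python
import torch

def top_eigenvector(a: torch.Tensor) -> torch.Tensor:
    """Top eigenvector of each symmetric matrix in a batch; shape (..., n, 1)."""
    # Eigenvalues come back in ascending order, so the last column is the top one.
    _, vecs = torch.linalg.eigh(a)
    return vecs[..., -1:]

torch.manual_seed(0)
x = torch.rand(4, 5, 3, dtype=torch.float64, requires_grad=True)
cov = x.transpose(1, 2) @ x                 # batch of four symmetric 3x3 matrices
w = torch.ones(3, 1, dtype=torch.float64)   # arbitrary fixed direction

v = top_eigenvector(cov)                    # shape (4, 3, 1)
# (v^T w)^2 is unchanged if v is flipped to -v, so its gradient is well defined
loss = ((v.transpose(1, 2) @ w) ** 2).sum()
loss.backward()                             # autograd handles the batched backward
print(v.shape, x.grad.shape)
```

Building cov from x inside the differentiated function keeps every gradcheck perturbation symmetric, which sidesteps the asymmetric-perturbation problem discussed above.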