Function for Trace


I am trying to write an autograd Function for torch.trace (one does not seem to exist yet). Here is my code:

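import torch

class CustomTrace(torch.autograd.Function):
    def forward(self, mat):
        # save the input so backward can read its size
        self.save_for_backward(mat)
        return torch.Tensor([torch.trace(mat)])
    def backward(self, g):
        mat, = self.saved_tensors
        # gradient of the trace: identity matrix scaled by the incoming gradient
        return torch.mul(torch.eye(int(mat.size()[0])),g[0])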
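For reference, the math behind this backward pass: for an n×m input A, the trace sums the main diagonal, so each entry's partial derivative is a Kronecker delta and the gradient is the (rectangular) identity scaled by the upstream gradient. The symbols A, L (a downstream loss) and g = ∂L/∂tr(A) are introduced here only for illustration; in the square case this is exactly torch.mul(torch.eye(n), g[0]).

```latex
\operatorname{tr}(A) = \sum_{k=1}^{\min(n,m)} A_{kk},
\qquad
\frac{\partial \operatorname{tr}(A)}{\partial A_{ij}} = \delta_{ij},
\qquad
\frac{\partial L}{\partial A} = g \, I_{n \times m}
```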

When I check my gradient with gradcheck, its precision is not as good as I expected. For example,

input = (torch.autograd.Variable(torch.randn(2,2).double(), requires_grad = True),)
torch.autograd.gradcheck(CustomTrace(), input, eps = 1e-4, atol = 1e-3)

returns True. However,

input = (torch.autograd.Variable(torch.randn(2,2).double(), requires_grad = True),)
torch.autograd.gradcheck(CustomTrace(), input, eps = 1e-6, atol = 1e-4)

returns False. I believe my math is right. Is there a problem with my code, or is this expected for some reason?

Thanks a lot!

Your function looks correct.
Some suggestions:

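  • You don't need to save the entire mat for backward, only mat.size().
  • mat.size()[0] -> mat.size(0)
  • You don't handle trace for non-square inputs.
  • The output type is not the same as the input type, because you use torch.Tensor, which creates a FloatTensor by default: a double input gives a float output. Instead, use input.new(...) so the output has the same type as the input. This downcast is likely also why gradcheck fails at eps=1e-6: the finite differences are taken on a float32 output.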
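To illustrate the last point, here is a small sketch assuming a recent PyTorch (the helper names are hypothetical). It mimics a forward that rebuilds its result as float32, then compares a central finite difference of the trace at eps=1e-6, the setting that failed in the question, against a version that keeps float64.

```python
import torch


def trace_as_float32(a: torch.Tensor) -> torch.Tensor:
    # Mimics the question's forward: the result is rebuilt as a float32 tensor,
    # so a float64 input is silently downcast.
    return torch.tensor([torch.trace(a).item()], dtype=torch.float32)


def trace_keep_dtype(a: torch.Tensor) -> torch.Tensor:
    # Keeps the input's dtype (float64 here).
    return torch.trace(a).reshape(1)


def central_difference(fn, a: torch.Tensor, i: int, j: int, eps: float) -> float:
    """Central finite-difference estimate of d fn(a) / d a[i, j]."""
    plus, minus = a.clone(), a.clone()
    plus[i, j] += eps
    minus[i, j] -= eps
    return ((fn(plus) - fn(minus)) / (2 * eps)).item()


# diag(0.5, 0.5) has trace 1.0; the exact derivative w.r.t. a[0, 0] is 1.
a = torch.eye(2, dtype=torch.float64) * 0.5
fd32 = central_difference(trace_as_float32, a, 0, 0, eps=1e-6)
fd64 = central_difference(trace_keep_dtype, a, 0, 0, eps=1e-6)
print(f"float32 output: {fd32:.6f}, float64 output: {fd64:.6f}")
```

The float32 estimate comes out near 0.983 (off by about 0.017, far outside atol=1e-4), because float32 rounding of values near 1.0 is comparable to the 1e-6 step. The float64 estimate matches the exact value 1 to many digits.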

Overall, here's a modified version that handles non-square inputs and implements the other suggestions:

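import torch

class CustomTrace(torch.autograd.Function):
    def forward(self, input):
        self.isize = input.size()  # only the size is needed for backward
        return input.new([torch.trace(input)])  # same type as the input
    def backward(self, grad_output):
        isize = self.isize
        # d trace / d input: (possibly rectangular) identity scaled by grad_output
        grad_input = grad_output.new(*isize).copy_(torch.eye(*isize))
        grad_input.mul_(grad_output[0])
        return grad_input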
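The code above uses the old-style Function API from PyTorch 0.1.x (instance forward/backward, state stored on self), which current PyTorch releases no longer accept. A minimal sketch of the same idea with static methods and ctx, assuming a recent PyTorch, checked in double precision:

```python
import torch
from torch.autograd import gradcheck


class CustomTrace(torch.autograd.Function):
    """Trace of a 2-D (possibly non-square) matrix as a new-style autograd Function."""

    @staticmethod
    def forward(ctx, input):
        # Only the shape is needed in backward, not the matrix itself.
        ctx.isize = input.size()
        # For floating-point inputs torch.trace keeps the dtype; reshape to 1 element.
        return torch.trace(input).reshape(1)

    @staticmethod
    def backward(ctx, grad_output):
        n, m = ctx.isize
        # d trace / d input is the n x m identity, scaled by the upstream gradient.
        eye = torch.eye(n, m, dtype=grad_output.dtype, device=grad_output.device)
        return eye * grad_output[0]


x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
print(gradcheck(CustomTrace.apply, (x,), eps=1e-6, atol=1e-4))
```

With float64 inputs the eps=1e-6, atol=1e-4 setting from the question passes, and the 2×3 input exercises the non-square case.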

Thanks a lot! Amazing to see how much these few lines of code can be improved.

Thanks for pointing out that trace also works for non-square matrices (the mathematician in me assumed that trace only works for square matrices). I will try to implement this.

Edit: I see. It already works for non-square matrices. Thanks a lot.
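A quick sketch confirming this with a recent PyTorch: torch.trace sums the main diagonal, a[i, i] for i < min(rows, cols), even when the input is not square.

```python
import torch

a = torch.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
# Sums a[0, 0] + a[1, 1] = 0 + 4.
print(torch.trace(a))  # tensor(4.)
```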

FYI, I think something a little strange is going on with gradcheck. For example, check out what happens with this simple function:

import torch
from torch.autograd import gradcheck

correlation_product = lambda f, d: torch.mul(d, torch.dot(f, d).expand_as(d))

print(gradcheck(correlation_product, (f, d), eps=1e-6, atol=1e-3))
print(gradcheck(correlation_product, (f, d), eps=1e-6, atol=1e-2))
print(gradcheck(correlation_product, (f, d), eps=1e-6, atol=1e-1))

prints False, False, True (f and d are declared as V(torch.FloatTensor(np.random.randn(5)), requires_grad=True)). My understanding is that this is using PyTorch’s built-in autograd, so it seems strange to have precision issues of this magnitude. This is on pytorch 0.1.11+8aa1cef.

You should ideally run gradcheck in double precision. Float (float32) precision may not be enough for the finite-difference estimate to agree with the analytical gradient.
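A sketch of the same check in double precision, assuming a recent PyTorch (plain tensors with requires_grad instead of Variable) and assuming correlation_product computes d scaled by torch.dot(f, d):

```python
import torch
from torch.autograd import gradcheck


def correlation_product(f, d):
    # Assumed form: d scaled by the dot product <f, d>, expanded to d's shape.
    return torch.mul(d, torch.dot(f, d).expand_as(d))


# Same shapes as in the post above, but created directly in float64.
f = torch.randn(5, dtype=torch.float64, requires_grad=True)
d = torch.randn(5, dtype=torch.float64, requires_grad=True)
print(gradcheck(correlation_product, (f, d), eps=1e-6, atol=1e-3))
```

In float64 the finite-difference and analytical gradients agree well within atol=1e-3, the tolerance that failed with float32 inputs.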


That was exactly the problem – thanks!

I am not sure whether this was different in March 2017, but there is now a ready-to-use trace function with autograd support: torch.trace.

Just mentioning this for everyone who gets here from a search engine.
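For example, a minimal sketch assuming a recent PyTorch: backpropagating through the built-in torch.trace yields the identity as the gradient, with no custom Function needed.

```python
import torch

x = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
torch.trace(x).backward()
# The gradient of the trace is the identity matrix.
print(torch.equal(x.grad, torch.eye(3, dtype=torch.float64)))  # True
```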