# Function for Trace

Hello,

I am trying to write a `Function` for `torch.trace` (it does not seem to exist so far). Here is my code:

```
import torch

class CustomTrace(torch.autograd.Function):

    def forward(self, mat):
        self.save_for_backward(mat)
        return torch.Tensor([torch.trace(mat)])

    def backward(self, g):
        mat, = self.saved_tensors
        return torch.eye(mat.size()[0]) * g[0]
```

When I try to check my result with `gradcheck`, the precision is not optimal. For example,

```
input = (torch.autograd.Variable(torch.randn(2, 2).double(), requires_grad=True),)
```

returns `True`. However,

```
input = (torch.autograd.Variable(torch.randn(2, 2).double(), requires_grad=True),)
```

returns `False`. I would hope that my math is right. Is there a problem with my code, or is this to be expected for some reason?

Thanks a lot!

Some suggestions:

• You don't need to save the entire `mat` for backward, only `mat.size()`.
• `mat.size()` -> `mat.size(0)`.
• You don't handle `trace` for non-square inputs.
• The output type is not the same as the input type, because you use `torch.Tensor`. Use `mat.new()` instead.

Overall, here’s a modified version that handles non-square inputs and implements the other suggestions.

```
class CustomTrace(torch.autograd.Function):

    def forward(self, input):
        self.isize = input.size()  # save only the shape, not the whole matrix
        return input.new([torch.trace(input)])

    def backward(self, grad_output):
        isize = self.isize
        # gradient of trace: 1 on the diagonal, 0 elsewhere, times grad_output
        grad_input = grad_output.new(isize).zero_()
        for i in range(min(isize[0], isize[1])):
            grad_input[i, i] = grad_output[0]
        return grad_input
```
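For example, one could exercise it on a non-square, double-precision input roughly like this (a sketch, assuming the legacy instance-style `Function` API used in this thread and that `gradcheck` is importable from `torch.autograd`; the tolerances are illustrative):

```
import torch
from torch.autograd import Variable, gradcheck

# non-square, double-precision input keeps the finite-difference check accurate
mat = Variable(torch.randn(3, 5).double(), requires_grad=True)

print(CustomTrace()(mat))                                     # 1-element tensor holding the trace
print(gradcheck(CustomTrace(), (mat,), eps=1e-6, atol=1e-4))  # expected: True
```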

Thanks a lot! Amazing to see how much these few lines of code can be improved.

Thanks for pointing out that trace also works for non-square matrices (the mathematician in me assumed that trace only works for square matrices). I will try to implement this.

Edit: I see. It already works for non-square matrices. Thanks a lot.

FYI I think there’s something a little strange going on with `gradcheck`. For example, check out what happens with this simple function:

`correlation_product = lambda f, d: torch.mul(d, torch.dot(f, d).expand_as(d))`

```
print(gradcheck(correlation_product, (f, d), eps=1e-6, atol=1e-3))
```

prints `False, False, True` (`f` and `d` are declared as `V(torch.FloatTensor(np.random.randn(5)), requires_grad=True)`). My understanding is that this is using PyTorch’s built-in autograd, so it seems strange to have precision issues of this magnitude. This is on pytorch `0.1.11+8aa1cef`.

You should ideally do gradchecks in double precision. Float precision might not be enough for the finite-difference estimate to agree with the analytical gradient.
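For example, redoing the `correlation_product` check above in double precision (a sketch using the current tensor API instead of `Variable`; the tolerances just mirror the ones above):

```
import torch
from torch.autograd import gradcheck

correlation_product = lambda f, d: torch.mul(d, torch.dot(f, d).expand_as(d))

# double-precision inputs make the finite-difference estimate accurate enough
f = torch.randn(5, dtype=torch.double, requires_grad=True)
d = torch.randn(5, dtype=torch.double, requires_grad=True)

print(gradcheck(correlation_product, (f, d), eps=1e-6, atol=1e-3))  # expected: True
```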


That was exactly the problem – thanks!

I am not sure if this was different in March 2017, but there is a ready-to-use trace function with autograd support now: `torch.trace`.

Just mentioning this for everyone who gets here via a search engine.
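For instance (a quick sketch with the current API; the gradient of the trace is 1 on the diagonal and 0 elsewhere):

```
import torch

x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
y = torch.trace(x)   # built-in, differentiable trace
y.backward()

print(x.grad)        # identity matrix: d(trace)/dx[i, j] is 1 if i == j, else 0
```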
