Jacobian with respect to a symmetric tensor

Hello,

I want to compute the gradient of a 3-by-3 tensor with respect to another 3-by-3 tensor, which yields a 3-by-3-by-3-by-3 tensor; see the following example code:

import torch

X = torch.tensor([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype=torch.double)
X.requires_grad_(True)

def computeY(input): return torch.pow(input, 2)   # elementwise square

dYdX = torch.autograd.functional.jacobian(computeY, X)   # shape (3, 3, 3, 3)

This does exactly what the Jacobian operation is supposed to do; however, it does not seem to take into account that X is symmetric.
If X is a general matrix with no symmetry information, then, for example, dY[1,2]/dX[2,1] = 0. This is exactly what autograd does by default.
If X is a symmetric matrix with X[i,j] = X[j,i], then, for example, dY[1,2]/dX[2,1] = dY[2,1]/dX[2,1] = dY[1,2]/dX[1,2] = 2X[1,2] = 2X[2,1] = 6. Does anyone know if there’s any way to provide such symmetry information to autograd?
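For concreteness, here is a small finite-difference sketch (not autograd, zero-based indices) of the off-diagonal derivative I have in mind:

import torch

# if X is constrained to stay symmetric, perturbing X[1, 0] also perturbs
# X[0, 1], so dY[0, 1]/dX[1, 0] = 2 * X[0, 1] = 6 for this matrix
X = torch.tensor([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype=torch.double)
eps = 1e-6

dX = torch.zeros_like(X)
dX[1, 0] = eps
dX[0, 1] = eps   # symmetric perturbation: both copies move together

fd = (torch.pow(X + dX, 2)[0, 1] - torch.pow(X, 2)[0, 1]) / eps
print(fd)        # ~6.0, i.e. 2 * X[0, 1]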

Many thanks!

Hi Bruce!

I don’t believe that there is any way to “tell” autograd that some matrix
is to be understood as a symmetric matrix. You can, however, rewrite
computeY() to treat its input as symmetric.

This may or may not fit your use case, but consider this illustrative script:

import torch
print (torch.__version__)

X = torch.tensor ([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype = torch.double)
X_tril = X.tril()   # represent symmetric matrix by its lower triangle

print (X_tril)

def computeY (input) :   # assume input is lower-triangular
    input_sym = input.tril() + input.T.triu (1)   # build explicitly symmetric representation of input
    return torch.pow (input_sym, 2)

dYdX = torch.autograd.functional.jacobian (computeY, X_tril)   # only lower triangle of dX part is non-zero

print (dYdX[1, 2, 1, 2])   # X_tril[1, 2] is ignored (and is zero)
print (dYdX[1, 2, 2, 1])   # X_tril[2, 1] is used for both input_sym[2, 1] and input_sym[1, 2]

Here is its output:

2.1.2
tensor([[1., 0., 0.],
        [3., 1., 0.],
        [5., 7., 1.]], dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(14., dtype=torch.float64)

Best.

K. Frank

Hi K. Frank

Thanks for the response. I have an idea that might be able to pass along the symmetry information; I’m still trying to make it work in autograd, if it’s possible. Here’s the idea: suppose we define a tensor X and tell autograd that it requires gradient info, we compute a function value Y(X), then compute another function value Z(Y), and finally compute the Jacobian dZ/dY instead of dZ/dX. Is this possible? If so, I could do something like:

def Y(symmetric_tensor): return 0.5 * ( symmetric_tensor.T + symmetric_tensor )

and then Z(Y) would be the actual function to take the Jacobian of, for example the square of the input. This should provide the symmetry information in the sense that, for example, dZ[1,2]/dY[2,1] = dZ(0.5(X[1,2]+X[2,1])) / d(0.5(X[1,2]+X[2,1])) != 0

import torch
X = torch.tensor([[1,3,5],[3,1,7],[5,7,1]],dtype=torch.double)

X.requires_grad_(True)

Y = 0.5 * (X+X.T)

def compute_Z(input): return torch.pow(input, 2)

print(torch.autograd.functional.jacobian(compute_Z, Y))

Unfortunately, this still outputs the same result without the symmetry :frowning:

Hi Bruce!

Try moving your symmetrization step inside of compute_Z():

>>> import torch
>>> print (torch.__version__)
2.1.2
>>>
>>> X = torch.tensor ([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype = torch.double)
>>>
>>> def compute_Z (input) :   # symmetrize input -- in a sense assumes that it is symmetric
...     Y = 0.5 * (input + input.T)
...     return torch.pow (Y, 2)
...
>>> dZdX = torch.autograd.functional.jacobian (compute_Z, X)
>>>
>>> print (dZdX[1, 2, 1, 2])
tensor(7., dtype=torch.float64)
>>> print (dZdX[1, 2, 2, 1])
tensor(7., dtype=torch.float64)

Does this work for you?

Best.

K. Frank

Hi K. Frank, again

Unfortunately, moving the symmetrization step into the function does not help; although the output is symmetric, the values are not correct.
This simpler test shows the problem: for any 3-by-3 tensor X, the function Y is simply X itself, plus the symmetrization step:

import torch
X = torch.tensor([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype=torch.double)
X.requires_grad_(True)
def compute_Y(X): return 0.5 * (X + X.T)
print('dYdX = ', torch.autograd.functional.jacobian(compute_Y, compute_Y(X)))

The correct output, with the symmetry information passed down correctly, would be a 4th-order symmetric tensor with only 1 and 0 components; however, the above code leaves 0.5 at the off-diagonal non-zero positions.
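Printing a few individual entries makes the problem concrete (a quick self-contained check, same X and compute_Y as above):

import torch

X = torch.tensor([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype=torch.double)
def compute_Y(X): return 0.5 * (X + X.T)

dYdX = torch.autograd.functional.jacobian(compute_Y, compute_Y(X))
print(dYdX[0, 0, 0, 0])   # tensor(1., ...)      diagonal position, as expected
print(dYdX[0, 1, 0, 1])   # tensor(0.5000, ...)  off-diagonal position
print(dYdX[0, 1, 1, 0])   # tensor(0.5000, ...)  off-diagonal position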

Hi Bruce!

This depends on what you think the jacobian with respect to a symmetric
matrix ought to be.

The core issue is that a 3x3 symmetric matrix has six degrees of freedom,
namely the three diagonal elements plus the three off-diagonal elements
(each of which occurs as two equal copies in the full 3x3 matrix).

In contrast, a general (that is, not necessarily symmetric) 3x3 matrix has
nine degrees of freedom – all nine elements of the matrix are independent
of one another.

Conceptually the gradient of a scalar function with respect to a 3x3
symmetric matrix also consists of six (independent) values. If you
express that gradient as a full 3x3 matrix, you might be “over-counting”
the off-diagonal elements of the gradient by a factor of two (depending
on how you define things) because there are two copies of those
off-diagonal elements.
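Here is a small, purely illustrative sketch of that factor of two, using a scalar
function of a symmetric matrix parameterized by its lower triangle:

import torch

X_tril = torch.tensor ([[1, 0, 0], [3, 1, 0], [5, 7, 1]], dtype = torch.double, requires_grad = True)

X_sym = X_tril.tril() + X_tril.tril (-1).T   # full symmetric matrix built from the six lower-triangle entries
f = (X_sym**2).sum()                         # scalar function of the symmetric matrix
f.backward()

print (X_tril.grad)
# the diagonal entries come out as 2 * X[i, i], but each off-diagonal entry comes
# out as 4 * X[i, j] -- twice the per-element value 2 * X[i, j] -- because each
# off-diagonal parameter feeds two copies in the full matrix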

You can view that factor of 0.5 as compensating for the over-counting.

If you don’t like that factor of 0.5, you can compute the jacobian the first
way I suggested and then symmetrize the result, or you can compute the
jacobian the second way I suggested and multiply the off-diagonal elements
by 2.0.
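For example, here is one way (just a sketch) to do the second version, rescaling
the off-diagonal positions of the input indices by 2.0 after the fact:

import torch

X = torch.tensor ([[1, 3, 5], [3, 1, 7], [5, 7, 1]], dtype = torch.double)

def compute_Z (input) :   # symmetrize input, as in the second suggestion
    Y = 0.5 * (input + input.T)
    return torch.pow (Y, 2)

dZdX = torch.autograd.functional.jacobian (compute_Z, X)

scale = 2.0 - torch.eye (3, dtype = torch.double)   # 2 off the diagonal, 1 on it
dZdX_rescaled = dZdX * scale                        # broadcasts over the last two (input) indices

print (dZdX_rescaled[1, 2, 2, 1])   # tensor(14., ...), i.e. 2 * X[1, 2]
print (dZdX_rescaled[1, 2, 1, 2])   # also 14. -- symmetric in the input indices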

You can use whatever definition you want for the off-diagonal elements
of the jacobian (as long as you use it consistently).

(My preference is to admit that a 3x3 symmetric matrix has only six
independent degrees of freedom and represent it (and the gradient)
as a triangular matrix with only six (non-zero) elements.)

Best.

K. Frank