Autograd matrix to matrix function

Hi!

I hope I’m in the right place to ask this question. I’m new to actually caring about how autograd works, so I’m trying to understand how I can define a new autograd function in the case where I map a matrix to a scalar via intermediate matrix transformations. I’m mainly wondering what is going on mathematically. I’m looking at PyTorch: Defining New autograd Functions — PyTorch Tutorials 1.11.0+cu102 documentation.

I realize that, using autograd, I can calculate the gradient of something like energy below:

import torch

def mat2mat_function(matrix):
    return matrix @ matrix

def energy(matrix):
    # build the scalar from the intermediate matrix-to-matrix result
    matrix_sq = mat2mat_function(matrix)
    return torch.trace(matrix_sq.T @ matrix_sq)
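
For example, something like the following should already give me the gradient of energy with plain autograd (the 4×4 size and the name X are just for illustration):

X = torch.randn(4, 4, requires_grad=True)
loss = energy(X)   # scalar
loss.backward()    # autograd computes d(loss)/dX
print(X.grad)      # same shape as X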

In my intermediate function mat2mat_function above, I pass in a matrix and get back its square. It is not clear to me what is going on autograd-wise. What is meant by applying the chain rule to a matrix-to-matrix function? In this case I can of course guess that it means multiplying the current gradient by 2X, since the function is X^2. But what if my matrix-to-matrix function is much more complicated? I looked into Gateaux derivatives, but they give operators between the spaces and not directly gradients. So what is going on, and how would I implement my own custom matrix-to-matrix autograd function?

Following the example in the tutorial, I can define a class:

class matrix_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input @ input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * 2 * input  # correct ???????
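
If I understand the tutorial correctly, I would then call it through apply, roughly like this:

X = torch.randn(4, 4, requires_grad=True)
Y = matrix_function.apply(X)  # uses my custom forward/backward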

If the Gateaux derivative of my intermediate matrix function were some much more complicated thing, which only returns a matrix when evaluated, what do I do then? Imagine, for instance, that I’m solving an equation A(X) = Y and inputting X to obtain Y.

  1. Your function only works for square matrices; I don’t know if that is on purpose.

  2. Autograd in the usual backward AD mode computes a gradient or, more generally, a vector-Jacobian product. The grad_output is the vector, and the backward should return this grad_output multiplied with the Jacobian of the function. I would expect the gradient to be something like grad_out @ input.T + input.T @ grad_out (see the sketch after this list).

  3. You can check with torch.autograd.gradcheck(matrix_function.apply, torch.randn(5, 5, requires_grad=True, dtype=torch.double)) or so.
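
To make points 2 and 3 concrete, here is a rough sketch of what I have in mind (not heavily tested; the double-precision input is there because gradcheck wants it):

import torch

class matrix_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input @ input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d(X @ X) = dX @ X + X @ dX, so the vector-Jacobian product is
        # grad_output @ X.T + X.T @ grad_output
        return grad_output @ input.T + input.T @ grad_output

X = torch.randn(5, 5, requires_grad=True, dtype=torch.double)
# gradcheck compares the analytical Jacobian from backward against a
# numerical one and raises if they disagree; it returns True on success
print(torch.autograd.gradcheck(matrix_function.apply, X))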

Best regards

Thomas