I’m trying to calculate the gradient of a function of the form A*A*X*W0*W1 with respect to A in PyTorch. I have calculated the gradient manually, but I’m not getting the same answer as PyTorch’s autograd. Can anyone help me figure out what I did wrong?

The problem most likely lies with the torch.eye(rows, cols) that you use.
Keep in mind that the backward pass computes a vector-Jacobian product, not the Jacobian itself. So you need to make sure that the vector you backprop through autograd matches the identity matrix you use in your manual computation.
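A minimal sketch of the point above (the shapes here are illustrative assumptions, not taken from the original code): when `Y = f(A)` is matrix-valued, `Y.backward(v)` computes the vector-Jacobian product for the cotangent `v`, which must have exactly the same shape as `Y`, and the resulting `A.grad` always has the shape of `A`.

```python
import torch

# Y = f(A) is matrix-valued, so backward() needs a "vector" (cotangent)
# of the same shape as Y; autograd then returns v^T J, not J itself.
A = torch.randn(3, 3, requires_grad=True)
X = torch.randn(3, 3)

Y = A @ A @ X           # stand-in for the full A*A*X*W0*W1 chain
v = torch.ones_like(Y)  # same shape as Y; torch.eye(3, 3) would also
                        # work here, but only because Y happens to be square
Y.backward(v)

# A.grad holds d(sum(v * Y))/dA and always has A's shape
assert A.grad.shape == A.shape
```

Any manual gradient you compare against therefore has to be the gradient of `sum(v * Y)` for the same `v`, not the full Jacobian.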

I’m getting the rows and cols for the identity matrix with rows, cols = AAXW0W1.size(), so I thought it matches the shape of the vector I’m backpropagating. Am I wrong?

Also, I was comparing the result with this function:

I don’t think these are doing the same computation.
You can see this by changing the shapes of your matrices so that they’re not all the same (but still valid for the mm): you’ll see that your function does not return something with the shape of A.
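A quick way to run that check yourself (the concrete sizes below are my own assumptions, chosen so the matrix products are valid but the matrices are not all square; note A itself must stay square because of A @ A). For reference, the manual vector-Jacobian product for Y = A A M with cotangent V is V Mᵀ Aᵀ + Aᵀ V Mᵀ, which does land on the shape of A:

```python
import torch

# Non-square shapes that still chain through mm:
A  = torch.randn(4, 4, requires_grad=True)  # square, since Y uses A @ A
X  = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 7)

Y = A @ A @ X @ W0 @ W1   # shape (4, 7)
v = torch.ones_like(Y)    # cotangent, same shape as Y

(grad_A,) = torch.autograd.grad(Y, A, grad_outputs=v)
assert grad_A.shape == A.shape  # (4, 4), regardless of Y's shape

# Manual VJP for Y = A @ A @ M with M = X @ W0 @ W1:
# d<v, Y> = <v @ M.T @ A.T + A.T @ v @ M.T, dA>
M = X @ W0 @ W1
manual = v @ M.T @ A.T + A.T @ v @ M.T
assert torch.allclose(grad_A, manual, atol=1e-5)
```

If your hand-derived formula produces something that isn’t shaped like A for these sizes, it can’t be the gradient with respect to A.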