More efficient autograd for matrix w.r.t. matrix gradient computation?

Hello! I have two tensors, K of shape (N, N) and X of shape (N, T, F), and I want to calculate the derivative of K[i, j] with respect to X[j, -1].
At the moment I'm calling torch.autograd.grad(K[i, j], X) and then indexing the result at [j, -1], but I'm wondering whether it's possible to differentiate directly with respect to X[j, -1], or to do something even more efficient.

The final goal is to calculate this:

# For each (i, j), take dK[i, j]/dX and keep the slice at [j, -1] (shape (F,)),
# then average over j; the result has shape (N, F).
torch.stack([
    torch.stack([
        torch.autograd.grad(K[i, j], X, retain_graph=True)[0][j, -1]
        for j in range(X.shape[0])
    ])
    for i in range(X.shape[0])
]).mean(1)
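
For reference, here is a self-contained toy version of the snippet above; the sizes N, T, F and the kernel computation are placeholders I made up just so it runs, not my actual model:

import torch

# Placeholder sizes, just for illustration.
N, T, F = 4, 5, 3
X = torch.randn(N, T, F, requires_grad=True)

# Placeholder kernel so the snippet is runnable: K[i, j] depends on the last
# time step of X[i] and X[j]. My real K is computed differently.
K = X[:, -1] @ X[:, -1].T  # shape (N, N)

result = torch.stack([
    torch.stack([
        torch.autograd.grad(K[i, j], X, retain_graph=True)[0][j, -1]
        for j in range(X.shape[0])
    ])
    for i in range(X.shape[0])
]).mean(1)

print(result.shape)  # torch.Size([4, 3]), i.e. (N, F)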

If you can think of a way to optimise this, please let me know!