Obtaining derivatives with PyTorch

I am trying to obtain hard-to-symbolically-compute derivatives of matrices, so I thought I could use PyTorch's autograd features.

I have y = f(x), where y is a vector of shape (n, 1) and x is a matrix of shape (p, q). y is built from a lot of operations involving x (multiplications, vertical stacking, etc.).

I'm trying to find the gradient dy/dx, which I suppose needs to have dimensions [n, p, q]. I'm not sure what v to pass to `y.backward(v)` to get the correct result: if I use `v = torch.ones_like(y)`, then `x.grad` comes back as a [p, q] matrix, where I would expect [n, p, q] (and then again, maybe I have a different problem?).

Here’s a sample code:

```python
import torch

n, p, q = 9, 2, 3
x = torch.randn(p, q, requires_grad=True)
A = torch.randn(q, p)
z = torch.randn(q, 1)

f = A @ x
y = None
for k in range(int(n / q)):
    if y is None:
        y = f @ z
    else:
        y = torch.vstack((y, f @ z))

    # prepare for next time step
    f = (A @ x) @ f

v = torch.ones_like(y)
y.backward(v, retain_graph=True)

print(x.grad)
>>> tensor([[13.4633, -3.9368,  7.8990],
            [21.1977, -4.3395, -5.6655]])
```

What is happening here is that the autograd engine expects the starting point for backward to be a scalar. If it is not, the `gradient` argument of `backward` is used as the vector v in a vector-Jacobian product, and the result is the gradient with respect to the inputs (in this case, x). When elements of x are used in multiple parts of the graph, the total derivative is computed (the sum of the gradients over all of those uses). You can observe this: the magnitude of `x.grad` grows as you make n larger.
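As a small illustration of this (a hypothetical toy f, not your matrix recursion): with v = ones, `backward` computes vᵀJ, so the per-element gradients are summed into `x.grad` and its shape stays that of x:

```python
import torch

# toy example: y = 2 * x, flattened into a (6, 1) "vector" output
x = torch.randn(2, 3, requires_grad=True)
y = (2 * x).reshape(-1, 1)

# backward(v) computes v^T J, not the full Jacobian J:
# with v = ones, the rows of J are summed into x.grad,
# so x.grad has the shape of x, i.e. (2, 3), not (6, 2, 3)
y.backward(torch.ones_like(y))
print(x.grad)  # every entry is 2.0, in a (2, 3) matrix
```

Each entry of x here feeds exactly one entry of y, so the summed gradient is simply 2 everywhere; when each entry is reused n times, the sums grow with n, which matches the growing magnitude of `x.grad` you observed.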

If you expect these to be separate, you might want to try calling backward on individual elements of `y` and manually constructing your `[n, p, q]` gradient tensor (e.g., with `torch.stack`).
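A sketch of that approach using `torch.autograd.grad`, which returns the gradient directly instead of accumulating into `.grad` (I am assuming x is defined as a `(p, q)` leaf with `requires_grad=True`, which was missing from your snippet):

```python
import torch

n, p, q = 9, 2, 3
x = torch.randn(p, q, requires_grad=True)
A = torch.randn(q, p)
z = torch.randn(q, 1)

f = A @ x
y = None
for k in range(int(n / q)):
    y = f @ z if y is None else torch.vstack((y, f @ z))
    f = (A @ x) @ f  # prepare for next time step

# gradient of each scalar y[k] w.r.t. the whole matrix x;
# retain_graph=True lets us backprop through the graph repeatedly
rows = [torch.autograd.grad(y[k, 0], x, retain_graph=True)[0] for k in range(n)]
dydx = torch.stack(rows)  # shape (n, p, q)
print(dydx.shape)
```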

Ok, thank you, I see. In that case I have a couple of follow-up questions.
If I understand you correctly, I would have to build dy/dx[k, i, j] step by step:

1. This would involve taking the gradient of y[k] with respect to every element x[i,j].
If I define x as in the code above, x[i,j].grad is None. If I manually create x11, x12, … and then construct x from them using stack, I can get the gradients, but it is very gawky. Is there a way to define x as a matrix and still compute the derivative with respect to each individual matrix entry?

2. Would I then call y[k].backward() every time before looping through all (i,j) to get the gradients?

Updated code with the questions in the comments:

```python
import torch

n, p, q = 9, 2, 3

# this is the gawky part: every entry is its own leaf tensor
x11, x12, x13 = [torch.randn(1, 1, requires_grad=True) for _ in range(3)]
x21, x22, x23 = [torch.randn(1, 1, requires_grad=True) for _ in range(3)]
X_lst = [[x11, x12, x13], [x21, x22, x23]]

x1 = torch.hstack((x11, x12, x13))
x2 = torch.hstack((x21, x22, x23))
x = torch.vstack((x1, x2))
# just to get the x

A = torch.randn(q, p)
z = torch.randn(q, 1)

y = None
for k in range(int(n / q)):
    if y is None:
        f = A @ x
        y = f @ z
    else:
        f = A @ x @ f
        y = torch.vstack((y, f @ z))

# need to do this? here?
# for k in range(n):
#     y[k].backward(retain_graph=True)

dY_dX = torch.zeros((n, p, q))
for i in range(p):
    for j in range(q):
        for k in range(n):
            # or do it here?
            y[k].backward(retain_graph=True)
            xx = X_lst[i][j]

            # because x[i,j].grad gives me None
            dY_dX[k, i, j] = xx.grad  # presumably dy[k]/dx[i,j]?
```

Is this the way to go?

I think something simpler can be done, e.g.:

```python
import torch

n, p, q = 9, 2, 3
x = torch.randn(p, q, requires_grad=True)
A = torch.randn(q, p)
z = torch.randn(q, 1)

f = A @ x
y = None
for k in range(int(n / q)):
    if y is None:
        y = f @ z
    else:
        y = torch.vstack((y, f @ z))

    # prepare for next time step
    f = (A @ x) @ f

dydx = list()
for i in range(n):
    y[i, 0].backward(retain_graph=True)
    dydx.append(x.grad.clone())  # save this row's gradient
    x.grad.zero_()               # reset so gradients do not accumulate
dydx = torch.stack(dydx)
print(dydx.shape)  # torch.Size([9, 2, 3])
```
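As a side note, if your PyTorch version has `torch.autograd.functional.jacobian` (1.5 and later), you can also wrap the whole computation in a function and get the full `[n, p, q]` Jacobian in one call; a sketch:

```python
import torch
from torch.autograd.functional import jacobian

n, p, q = 9, 2, 3
A = torch.randn(q, p)
z = torch.randn(q, 1)

def f_of_x(x):
    # same recursion as above, packaged as a function of x
    f = A @ x
    y = None
    for k in range(int(n / q)):
        y = f @ z if y is None else torch.vstack((y, f @ z))
        f = (A @ x) @ f
    return y.squeeze(1)  # shape (n,)

x = torch.randn(p, q)
J = jacobian(f_of_x, x)  # output shape + input shape = (n, p, q)
print(J.shape)  # torch.Size([9, 2, 3])
```

This avoids the manual `retain_graph`/`zero_` bookkeeping, at the cost of re-running the function under the hood.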

This works very well, thank you, but I have another follow-up question.
Suppose now that y is not a vector (n) but a matrix (n, n). Now I need to calculate:

```python
for k in range(n):
    for l in range(n):