Computing the derivative of NN output with respect to the input

I’m currently trying to implement Physics-Informed Neural Networks (PINNs) to solve ODEs. As part of this, I need to compute the derivative of the NN output with respect to the inputs. So far I have used the code

dy_dt = torch.autograd.grad(y, t, torch.ones_like(y), create_graph=True)[0]

where t is a tensor of shape [N, 1].
This worked for learning scalar-valued functions, where the NN had one input node and one output node and returned a tensor of shape [N, 1]. However, I now want to solve planar systems, so my NN should have one input node and two output nodes. When I create a training tensor of shape [N, 1], my NN output has shape [N, 2], which seems fine, but the resulting derivative still has shape [N, 1]. I also tried to use

torch.autograd.functional.jacobian(model, t)

but this returned a tensor of shape [N, 2, N, 1]. I think it should be possible to extract the right derivatives from this, but I was wondering whether there is a way to return a tensor of shape [N, 2], so that each row is the derivative of the output for a specific input with respect to that input.
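
For reference, here is a minimal example that reproduces the shapes I describe (the small model is just a stand-in for my actual network):

import torch

N = 3
t = torch.randn(N, 1, requires_grad=True)
# stand-in model with one input node and two output nodes
model = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 2))

y = model(t)                                          # shape [3, 2]
dy_dt = torch.autograd.grad(y, t, torch.ones_like(y), create_graph=True)[0]
print(dy_dt.shape)    # torch.Size([3, 1]) -- the derivatives of the two outputs get summed together
jac = torch.autograd.functional.jacobian(model, t)
print(jac.shape)      # torch.Size([3, 2, 3, 1])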

Hi Victor!

I assume that N is the batch dimension of your input tensor.

I also assume that N is the batch dimension of your output tensor and
that your model doesn’t mix batch elements together, that is, that
model (t)[i] only depends on t[i].

Based on my assumption that your model doesn’t mix together batch
elements, the “batch-off-diagonal” elements of the jacobian will be zero.

In your case, it might be simplest to extract the non-zero elements
of the jacobian (which will be the derivatives of your two scalar outputs
with respect to your single scalar input on a per-batch-element basis).

If you do this, however, autograd will be unnecessarily computing the
zero, batch-off-diagonal elements of the full jacobian, which could matter
for your performance if N is large.

As to whether there is a way to get the [N, 2] tensor you want directly:
yes. Probably the best way would be to apply vmap to jacfwd.* This will
apply jacfwd to model on a per-batch-element basis, giving your desired
derivatives, without running a python loop (which would dramatically slow
things down).

Here is a script that illustrates these two approaches:

import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

t = torch.randn (3, 1)                                 # batch of three scalar inputs
# toy model with one input node and two output nodes
model = torch.nn.Sequential (torch.nn.Linear (1, 2), torch.nn.Tanh(), torch.nn.Linear (2, 2))

jac = torch.autograd.functional.jacobian (model, t)    # "batch-off-diagonal" elements are zero
print ('jac = ...')
print (jac)

# equivalently, could have used
#   jac = torch.func.jacfwd (model) (t)
# or
#   jac = torch.func.jacrev (model) (t)

grad = jac.squeeze().diagonal (dim1 = 0, dim2 = 2).T   # extract non-zero, "diagonal" elements
print ('grad = ...')
print (grad)

gradB = torch.func.vmap (torch.func.jacfwd (model), in_dims = (0,)) (t).squeeze()   # per-batch-element jacobian, shape [3, 2]
print ('gradB = ...')
print (gradB)

print ('torch.equal (grad, gradB):', torch.equal (grad, gradB))

And here is its output:

2.3.1
jac = ...
tensor([[[[ 0.1516],
          [ 0.0000],
          [ 0.0000]],

         [[-0.0771],
          [ 0.0000],
          [ 0.0000]]],


        [[[ 0.0000],
          [ 0.0703],
          [ 0.0000]],

         [[ 0.0000],
          [-0.0394],
          [ 0.0000]]],


        [[[ 0.0000],
          [ 0.0000],
          [ 0.1426]],

         [[ 0.0000],
          [ 0.0000],
          [-0.0732]]]])
grad = ...
tensor([[ 0.1516, -0.0771],
        [ 0.0703, -0.0394],
        [ 0.1426, -0.0732]])
gradB = ...
tensor([[ 0.1516, -0.0771],
        [ 0.0703, -0.0394],
        [ 0.1426, -0.0732]], grad_fn=<SqueezeBackward0>)
torch.equal (grad, gradB): True

*) I used jacfwd() rather than jacrev() because you have fewer inputs
(per batch element), namely one, than outputs, namely two (although with
only two output elements, this won’t matter much).
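
As an aside, since your original code already calls torch.autograd.grad()
with create_graph = True, you could also build the [N, 2] result with one
autograd.grad() call per output column. This is just a sketch under the same
no-batch-mixing assumption (gradC is my name for the result, and t has to
require grad here):

tg = t.detach().requires_grad_()   # same inputs, but tracked by autograd
y = model (tg)                     # shape [N, 2]
dy0_dt = torch.autograd.grad (y[:, 0], tg, torch.ones_like (y[:, 0]), create_graph = True)[0]
dy1_dt = torch.autograd.grad (y[:, 1], tg, torch.ones_like (y[:, 1]), create_graph = True)[0]
gradC = torch.cat ([dy0_dt, dy1_dt], dim = 1)   # shape [N, 2]; columns are dy0/dt and dy1/dt

(create_graph = True implies retain_graph, so the second autograd.grad() call
can reuse the same graph, and the resulting derivatives stay differentiable
for use in your loss.)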

Best.

K. Frank
