I’m currently trying to implement Physics-Informed Neural Networks (PINNs) to solve ODEs. As part of this, I need to compute the derivative of the NN with respect to its inputs. So far I have used the following code
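(a sketch of the standard autograd.grad pattern; the grad_outputs and create_graph details here are as in typical PINN examples):

u = model (t)                                          # u has shape [N, 1]
du_dt = torch.autograd.grad (
    u, t,
    grad_outputs = torch.ones_like (u),                # sums the jacobian rows per batch element
    create_graph = True,                               # so that du_dt can appear in the training loss
)[0]                                                   # du_dt has shape [N, 1]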
Where t is some tensor with shape [N,1], created with requires_grad = True (so that autograd.grad can differentiate with respect to it).
This worked for learning scalar-valued functions, where the NN had one input node and one output node and returned a tensor of shape [N,1]. However, now I want to try to solve planar systems, so my NN should have one input node and two output nodes. When I create a training tensor of shape [N,1], my NN output has shape [N,2], which seems fine, but the resulting derivative still has shape [N,1]. I also tried to use
torch.autograd.functional.jacobian(model, t)
but this returned a tensor of shape [N,2,N,1]. I think it should be possible to extract the right derivatives from this, but I was wondering if there is a way to return a tensor of shape [N,2] directly, so that each row is the derivative of the output for a specific input with respect to that input.
I assume that N is the batch dimension of your input tensor. I also assume that N is the batch dimension of your output tensor and that your model doesn’t mix batch elements together, that is, that model (t)[i] only depends on t[i].
Based on my assumption that your model doesn’t mix batch elements together, the “batch-off-diagonal” elements of the jacobian will be zero. In your case, it might be simplest to just extract the non-zero elements of the jacobian (which are the derivatives of your two scalar outputs with respect to your single scalar input, on a per-batch-element basis). If you do this, however, autograd will unnecessarily compute the zero, batch-off-diagonal elements of the full jacobian, which could matter for performance if N is large.
Yes, there is a way to get the shape-[N,2] result directly. Probably the best approach is to apply vmap to jacfwd.* This applies jacfwd to model on a per-batch-element basis, giving your desired derivatives without running a python loop (which would dramatically slow things down).
Here is a script that illustrates these two approaches:
import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

t = torch.randn (3, 1)                                 # batch of N = 3 scalar inputs
model = torch.nn.Sequential (torch.nn.Linear (1, 2), torch.nn.Tanh(), torch.nn.Linear (2, 2))   # one input node, two output nodes

jac = torch.autograd.functional.jacobian (model, t)    # shape [3, 2, 3, 1]; "batch-off-diagonal" elements are zero
print ('jac = ...')
print (jac)

# equivalently, could have used
# jac = torch.func.jacfwd (model) (t)
# or
# jac = torch.func.jacrev (model) (t)

grad = jac.squeeze().diagonal (dim1 = 0, dim2 = 2).T   # extract non-zero, "diagonal" elements; shape [3, 2]
print ('grad = ...')
print (grad)

gradB = torch.func.vmap (torch.func.jacfwd (model), in_dims = (0,)) (t).squeeze()   # jacfwd per batch element; shape [3, 2]
print ('gradB = ...')
print (gradB)

print ('torch.equal (grad, gradB):', torch.equal (grad, gradB))
*) I used jacfwd() rather than jacrev() because you have fewer inputs (per batch element), namely one, than outputs, namely two (although with only two output elements, this won’t matter much).
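To connect this back to your PINN use case, here is a sketch of how the shape-[N,2] derivative could enter an ODE residual loss. The right-hand side f below is just a placeholder (a harmonic oscillator written as a first-order system), not something from your setup:

import torch

_ = torch.manual_seed (2024)

# placeholder right-hand side of the planar system x'(t) = f (x (t))
def f (x):                                             # x has shape [N, 2]
    return torch.stack ((x[:, 1], -x[:, 0]), dim = 1)  # x'' = -x, written as a first-order system

model = torch.nn.Sequential (torch.nn.Linear (1, 2), torch.nn.Tanh(), torch.nn.Linear (2, 2))
t = torch.rand (8, 1)                                  # N = 8 collocation points

dx_dt = torch.func.vmap (torch.func.jacfwd (model), in_dims = (0,)) (t).squeeze()   # shape [N, 2]
residual = dx_dt - f (model (t))                                                    # shape [N, 2]
loss = residual.pow (2).mean()                                                      # PINN residual loss
print ('loss =', loss.item())

Minimizing such a loss (together with a term enforcing the initial condition) is the usual PINN training objective.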