Can the computation of n-th order derivatives be sped up?

Hey,

for my current project I'm using PyTorch to evaluate PDEs, and I need to compute up to fourth-order derivatives of the network output w.r.t. its input. Right now I'm computing them quite naively by just passing each derivative into the next grad call, like so:

import torch

net = torch.nn.Sequential(
  torch.nn.Linear(2, 50),
  torch.nn.Tanh(),
  torch.nn.Linear(50, 1)
)

x = torch.randn(100, 2).requires_grad_(True)

for ep in range(epochs):
  optimizer.zero_grad()

  z = net(x)
  # differentiate repeatedly w.r.t. the first input dimension,
  # keeping the graph so the next grad call (and loss.backward()) can use it
  dzdx = torch.autograd.grad(z, x, grad_outputs=torch.ones_like(z), create_graph=True)[0][:, 0:1]
  dzdxx = torch.autograd.grad(dzdx, x, grad_outputs=torch.ones_like(dzdx), create_graph=True)[0][:, 0:1]
  dzdxxx = torch.autograd.grad(dzdxx, x, grad_outputs=torch.ones_like(dzdxx), create_graph=True)[0][:, 0:1]
  dzdxxxx = torch.autograd.grad(dzdxxx, x, grad_outputs=torch.ones_like(dzdxxx), create_graph=True)[0][:, 0:1]

  loss = some_comb_of_derivatives()
  loss.backward()
  optimizer.step()
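
The same pattern can also be written as a small loop, in case that makes it clearer what I'm doing; nth_derivative is just a helper name I'm using for this post, not an existing PyTorch function:

def nth_derivative(y, x, n, dim=0):
  # differentiate y w.r.t. x n times, each time keeping only column `dim`
  # of the gradient; create_graph=True so higher orders stay differentiable
  for _ in range(n):
    grad = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]
    y = grad[:, dim:dim + 1]
  return y

dzdxxxx = nth_derivative(net(x), x, n=4)  # fourth derivative w.r.t. the first input

Functionally this should be identical to the explicit version above, so I assume the cost is the same; I only include it to make the pattern explicit.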

This is really slow, however, and consumes a lot of memory (which is expected, since the graph has to be kept around for every differentiation), especially for bigger networks and batch sizes. Am I doing something wrong conceptually, or is there a way to speed up the computation of the derivatives?