Essentially, I need to train on a standard prediction task and, at the same time, calculate the Jacobian of the scores (defined by the output of a large network) with respect to the input images. The code snippet looks something like the following:

```python
import torch
from torch.func import vmap, jacrev

# So this pretrain_net is pretty big
# This part is standard neural net training
pretrain_net = torch.load("very_large_pretrain_net.pt")
x = torch.rand(B, 3, 224, 224)  # some batched images
optim.zero_grad()
output = pretrain_net(x)
loss = some_loss(output, ground_truth)
loss.backward()
optim.step()

# But beyond the standard training, I also need to compute the Jacobian of
# the network output with respect to the input images
def calc_score(input_im):
    # Inside vmap, input_im is a single image of shape (3, 224, 224),
    # so add a batch dimension of size 1 at the front (not at dim 1)
    input_im = input_im.unsqueeze(0)
    output = pretrain_net(input_im)  # say output is of shape 1x1024
    score = torch.mean(output, -1).squeeze(0)  # one scalar score per image
    return score

# Calculate the sample-wise Jacobian, shape (B, 3, 224, 224)
jacobian = vmap(jacrev(calc_score))(x)
```
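For reference, here is a minimal sketch of what the Jacobian step computes, using a tiny hypothetical stand-in network (`net`, with no BatchNorm or Dropout, so samples do not interact) instead of the real pretrained model. Because each score depends only on its own image, the per-sample Jacobian from `vmap(jacrev(...))` should match the gradient of the summed scores taken in a single batched backward pass:

```python
import torch
import torch.nn as nn
from torch.func import vmap, jacrev

torch.manual_seed(0)

# Hypothetical tiny stand-in for pretrain_net (no BatchNorm/Dropout)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 16),
)

B = 4
x = torch.rand(B, 3, 8, 8)  # small images to keep the example cheap

def calc_score(input_im):
    # input_im is one image (3, 8, 8) inside vmap; add a batch dim of 1
    out = net(input_im.unsqueeze(0))   # (1, 16)
    return out.mean(-1).squeeze(0)     # scalar score

# Per-sample Jacobian via vmap(jacrev): shape (B, 3, 8, 8)
jac_vmap = vmap(jacrev(calc_score))(x)

# Same quantity from one ordinary backward pass over the whole batch:
# d(sum_i score_i)/d x_j = d score_j / d x_j when samples are independent
x_req = x.clone().requires_grad_(True)
scores = net(x_req).mean(-1)           # (B,)
(jac_batched,) = torch.autograd.grad(scores.sum(), x_req)

print(jac_vmap.shape)
print(torch.allclose(jac_vmap, jac_batched, atol=1e-5))
```

This equivalence only holds while the samples are independent; a BatchNorm layer in training mode, for instance, couples them and the summed-gradient shortcut no longer gives the per-sample Jacobian.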

I notice that the computation in my snippet is fairly wasteful. I essentially run the forward and backward passes through this very large pretrained network twice: once in the standard prediction task and once (maybe even multiple times? I’m assuming vmap stores the gradient results for each sample) during the Jacobian calculation stage. As a result, I quickly ran out of GPU memory. Do you have any suggestions on how to save some memory in this case? The prediction task and the Jacobian calculation share the same network, so running multiple backward passes and storing multiple copies of this gradient information is definitely not necessary.
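One possible way to share the work, sketched under the assumption that the per-sample scores are independent of each other (no BatchNorm in training mode, for example): run a single forward pass with `x.requires_grad_(True)`, take the Jacobian as the gradient of the summed scores with `retain_graph=True`, then call `loss.backward()` on the same graph. The network, loss (`nn.MSELoss`), optimizer and targets below are hypothetical stand-ins, not the original model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for pretrain_net, some_loss, optim and ground_truth
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 16),
)
optim = torch.optim.SGD(net.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

B = 4
x = torch.rand(B, 3, 8, 8, requires_grad=True)  # track gradients w.r.t. the images
ground_truth = torch.rand(B, 16)

optim.zero_grad()
output = net(x)                      # the single forward pass, shared below
loss = loss_fn(output, ground_truth)
score = output.mean(-1)              # (B,) one score per image

# Per-sample Jacobian d score_i / d x_i, reusing the stored activations.
# retain_graph=True keeps the graph alive for loss.backward() below.
(jacobian,) = torch.autograd.grad(score.sum(), x, retain_graph=True)

# Accumulate gradients only into the parameters (not into x.grad)
loss.backward(inputs=list(net.parameters()))
optim.step()
```

This still performs two backward passes, but the activations are stored only once and there is no per-sample replication from `vmap`. If the Jacobian itself has to enter the loss (for instance as a penalty term), `torch.autograd.grad` would need `create_graph=True`, which costs additional memory.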