Forward-over-reverse HVP slower than reverse-over-reverse HVP with torch.func API

Hello,
I am benchmarking the different ways to compute Hessian-vector products with the torch.func API. I am a bit surprised by the results I get, since the forward-over-reverse approach appears significantly slower than the reverse-over-reverse one. Am I doing something wrong?

import torch
from torchvision.models import resnet18
from functorch.experimental import replace_all_batch_norm_modules_


from time import perf_counter

batch_size = 4
num_classes = 1000
model = replace_all_batch_norm_modules_(resnet18()).cuda()

batch = {
    'images': torch.randn(batch_size, 3, 224, 224).cuda(),
    'labels': torch.randint(0, num_classes, (batch_size,)).cuda()
}


def loss(params):
    """loss function used for training."""
    logits = torch.func.functional_call(model, params,
                                        (batch['images'], ))
    res = torch.nn.functional.cross_entropy(logits, batch['labels'])
    return res


def hvp_forward_over_reverse(x, v):
    """Forward-over-reverse: JVP through the gradient function."""
    return torch.func.jvp(torch.func.grad(loss), (x, ), (v, ))[1]


def hvp_reverse_over_forward(x, v):
    """Reverse-over-forward: gradient of the JVP of the loss."""

    def jvp_fun(x, v):
        return torch.func.jvp(loss, (x, ), (v, ))[1]

    return torch.func.grad(jvp_fun)(x, v)


def hvp_reverse_over_reverse(x, v):
    """Reverse-over-reverse: gradient of the scalar <grad(loss), v>."""
    grad_fun = torch.func.grad(loss)

    return torch.func.grad(lambda y: sum(
        torch.dot(a.ravel(), b.ravel())
        for a, b in zip(grad_fun(y).values(), v.values()))
    )(x)


if __name__ == "__main__":
    params = dict(model.named_parameters())
    v = torch.func.grad(loss)(params)

    print("Forward-over-reverse")
    start = perf_counter()
    hvp_forward_over_reverse(params, v)
    print(perf_counter() - start)

    print("Reverse-over-forward")
    start = perf_counter()
    hvp_reverse_over_forward(params, v)
    print(perf_counter() - start)

    print("Reverse-over-reverse")
    start = perf_counter()
    hvp_reverse_over_reverse(params, v)
    print(perf_counter() - start)

Output:

Forward-over-reverse
0.27017239201813936
Reverse-over-forward
0.043903919868171215
Reverse-over-reverse
0.09637519344687462

Best,

You probably want to call torch.cuda.synchronize() before reading the clock and average your measurements over many runs to reduce variance. Forward-over-reverse is also the first HVP you compute, so that first measurement likely includes CUDA warm-up and initial memory allocations.
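
Something along the lines of the sketch below (reusing your loss, params and v; the benchmark helper is just an illustrative name) should give more comparable numbers: one warm-up call per method, explicit synchronization, and an average over several runs.

def benchmark(hvp_fn, x, v, n_runs=10):
    """Average the runtime of hvp_fn over n_runs after one warm-up call."""
    hvp_fn(x, v)              # warm-up: CUDA init, allocations, first-call overhead
    torch.cuda.synchronize()  # make sure the warm-up kernels have finished
    start = perf_counter()
    for _ in range(n_runs):
        hvp_fn(x, v)
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return (perf_counter() - start) / n_runs


for name, fn in [("Forward-over-reverse", hvp_forward_over_reverse),
                 ("Reverse-over-forward", hvp_reverse_over_forward),
                 ("Reverse-over-reverse", hvp_reverse_over_reverse)]:
    print(name, benchmark(fn, params, v))

torch.utils.benchmark.Timer is another option that takes care of the CUDA synchronization for you.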