Hello,
I am benchmarking the different ways to perform Hessian-vector products (HVPs) with the torch.func API. I am a bit surprised by the results I get, since the forward-over-reverse approach seems significantly slower than the reverse-over-reverse one. Am I doing something wrong?
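For reference, all three functions below should compute the same quantity,

    hvp(x, v) = ∇²ℓ(x) v = d/dt ∇ℓ(x + t v)|_{t=0} = ∇_x ⟨∇ℓ(x), v⟩,

where the middle expression is what forward-over-reverse evaluates (a JVP of the gradient) and the right-hand one is what the two reverse-over-* variants evaluate (the gradient of a scalar dot product). Here is the script: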
import torch
from torchvision.models import resnet18
from functorch.experimental import replace_all_batch_norm_modules_
from time import perf_counter
batch_size = 4
num_classes = 1000
model = replace_all_batch_norm_modules_(resnet18()).cuda()
batch = {
    'images': torch.randn(batch_size, 3, 224, 224).cuda(),
    'labels': torch.randint(0, num_classes, (batch_size,)).cuda(),
}
def loss(params):
    """Loss function used for training."""
    logits = torch.func.functional_call(model, params, (batch['images'],))
    return torch.nn.functional.cross_entropy(logits, batch['labels'])
def hvp_forward_over_reverse(x, v):
    # Forward-mode JVP of the reverse-mode gradient: d/dt grad(loss)(x + t v).
    return torch.func.jvp(torch.func.grad(loss), (x,), (v,))[1]
def hvp_reverse_over_forward(x, v):
    # Reverse-mode gradient (w.r.t. x) of the forward-mode directional
    # derivative <grad(loss)(x), v>.
    def jvp_fun(x, v):
        return torch.func.jvp(loss, (x,), (v,))[1]
    return torch.func.grad(jvp_fun)(x, v)
def hvp_reverse_over_reverse(x, v):
    # Gradient of the scalar <grad(loss)(x), v>: two reverse-mode passes.
    grad_fun = torch.func.grad(loss)
    return torch.func.grad(
        lambda y: sum(torch.dot(a.ravel(), b.ravel())
                      for a, b in zip(grad_fun(y).values(), v.values()))
    )(x)
if __name__ == "__main__":
    params = dict(model.named_parameters())
    # Use the gradient itself as the tangent vector v.
    v = torch.func.grad(loss)(params)

    print("Forward-over-reverse")
    start = perf_counter()
    hvp_forward_over_reverse(params, v)
    print(perf_counter() - start)

    print("Reverse-over-forward")
    start = perf_counter()
    hvp_reverse_over_forward(params, v)
    print(perf_counter() - start)

    print("Reverse-over-reverse")
    start = perf_counter()
    hvp_reverse_over_reverse(params, v)
    print(perf_counter() - start)
Output:
Forward-over-reverse
0.27017239201813936
Reverse-over-forward
0.043903919868171215
Reverse-over-reverse
0.09637519344687462
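One thing I was not sure about when timing: CUDA kernels are launched asynchronously, so perf_counter can return before the GPU has actually finished, and the first call to each function also pays one-off setup costs. Here is a minimal timing harness I could use to double-check (just a sketch; the helper name time_hvp and the warmup/repeat counts are mine, and the only extra call it relies on is torch.cuda.synchronize()):

def time_hvp(fn, params, v, n_warmup=3, n_repeats=10):
    # Warm up so one-off costs are excluded from the measurement.
    for _ in range(n_warmup):
        fn(params, v)
    torch.cuda.synchronize()  # make sure all queued kernels are done
    start = perf_counter()
    for _ in range(n_repeats):
        fn(params, v)
    torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (perf_counter() - start) / n_repeats

For example: print(time_hvp(hvp_forward_over_reverse, params, v)).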
Best,