Using the torch.func namespace requires you to rewrite how you compute the gradients entirely. You can’t mix and match torch.autograd and torch.func, especially when using torch.func.vmap (at least to my knowledge).
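For a rough idea of what the functional version looks like, here is a minimal sketch of per-sample gradients built from torch.func.functional_call, torch.func.grad and torch.func.vmap. The toy Linear model, the loss_fn name and the random data are placeholders made up for illustration, so adapt them to your own model and loss.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical toy model, only here to have some parameters to differentiate.
model = torch.nn.Linear(3, 1)

# torch.func works on an explicit dict of parameters instead of module state.
params = {name: p.detach() for name, p in model.named_parameters()}

def loss_fn(params, x, y):
    # Loss for ONE sample: x has shape (3,), y has shape (1,).
    pred = functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)
    return ((pred - y) ** 2).mean()

# Placeholder batch of 8 samples.
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# grad differentiates w.r.t. the first argument (params);
# vmap maps over the batch dimension of x and y, sharing params.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)

for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))
```

Note that there is no .backward() call and no reading of .grad anywhere: the gradients come back as a dict with a leading batch dimension, one entry per parameter.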
I have some previous examples on the forums showing how to compute gradients with a ‘functional’ approach here: