Just an update for anyone who reads this thread in the future: as of PyTorch 2.0, the functorch library is included in PyTorch itself as `torch.func`. For the most part you can replace `functorch` with `torch.func` and keep the same syntax. The main exception is an `nn.Module`: you'll need to create a ‘functional’ version of your model, for example with `torch.func.functional_call`.
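For example, to compute the Jacobian of the model's output with respect to each input sample:

```python
import torch
from torch.func import vmap, jacrev, functional_call

model = myModel(*args, **kwargs)  # our network
params = dict(model.named_parameters())
inputs = torch.randn(batch_size, input_size)  # random input data

def fmodel(params, inputs):  # functional version of model
    return functional_call(model, params, inputs)

# argnums=1: differentiate w.r.t. inputs; in_dims=(None, 0): share params, map over the batch
result = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)
```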
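If you want something you can run directly, here is a minimal self-contained sketch of the same pattern. It assumes a single `nn.Linear` layer as a stand-in for `myModel` and arbitrary sizes; neither comes from the original example. A linear layer makes the result easy to check, because the Jacobian of `W x + b` with respect to `x` is just `W` for every sample.

```python
import torch
import torch.nn as nn
from torch.func import vmap, jacrev, functional_call

torch.manual_seed(0)
batch_size, input_size, output_size = 4, 3, 2  # arbitrary sizes for the sketch

# Stand-in for myModel: a single linear layer (an assumption, not the original model)
model = nn.Linear(input_size, output_size)
params = dict(model.named_parameters())
inputs = torch.randn(batch_size, input_size)

def fmodel(params, x):
    # Under vmap, x is a single sample of shape (input_size,)
    return functional_call(model, params, (x,))

# Jacobian of each sample's output w.r.t. that sample's input
result = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)

print(result.shape)  # torch.Size([4, 2, 3]): (batch, output, input)
# For y = W x + b, dy/dx = W for every sample
print(torch.allclose(result, model.weight.expand(batch_size, -1, -1)))  # True
```

With a nonlinear model the per-sample Jacobians would differ from one another, but the shape `(batch_size, output_size, input_size)` stays the same.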
The documentation for torch.func can be found here.
Also, if you’re migrating from functorch to torch.func, there is a documentation page covering the changes between them here.
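As a rough illustration of the most common change for `nn.Module`s: functorch's `make_functional` returned a function plus a tuple of parameters, whereas with `torch.func` you keep the original module and pass its parameters to `functional_call` as a dict keyed by name. The sketch below assumes a small `nn.Linear` model and a hypothetical loss (the mean output) purely for illustration; the functorch lines are comments for comparison and are not executed.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)
model = nn.Linear(3, 1)  # illustrative model, not from the original post
x = torch.randn(5, 3)

# functorch style (for comparison only, not run here):
#   fmodel, params = functorch.make_functional(model)
#   out = fmodel(params, x)
# torch.func style: parameters live in a plain dict keyed by name
params = {k: v.detach() for k, v in model.named_parameters()}

out = functional_call(model, params, (x,))
print(torch.allclose(out, model(x)))  # True

def loss_fn(params, x):
    # Hypothetical loss for illustration: mean of the model output
    return functional_call(model, params, (x,)).mean()

# grad differentiates w.r.t. the first argument and returns a dict with the same keys
grads = grad(loss_fn)(params, x)
print(sorted(grads.keys()))  # ['bias', 'weight']
```

Because the parameters are an ordinary dict, transforms such as `grad`, `jacrev` and `vmap` return results with the same keys, which tends to make them easier to match back to the module than functorch's positional tuple.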