Just an update for anyone who reads this thread in the future: as of PyTorch 2.0, the `functorch` library is included in PyTorch as `torch.func`. So you can replace `functorch` with `torch.func`, and for the most part the syntax is the same. The main exception is an `nn.Module`: instead of `functorch.make_functional`, you'll need to create a 'functional' version of your model with `torch.func.functional_call`.
For example,
```python
import torch
from torch.func import vmap, jacrev, functional_call

model = myModel(*args, **kwargs)  # our network
params = dict(model.named_parameters())
inputs = torch.randn(batch_size, input_size)  # random input data

def fmodel(params, inputs):  # functional version of model
    return functional_call(model, params, (inputs,))

# Per-sample Jacobian of the output w.r.t. the input (argnums=1).
# params are shared across the batch (in_dims=None); inputs are batched along dim 0.
result = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)
```
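If you want to check shapes before plugging in your own network, here is a self-contained sketch of the same pattern. It assumes a small hypothetical `nn.Sequential` MLP in place of `myModel` and arbitrary sizes. It computes per-sample Jacobians w.r.t. the input (`argnums=1`) and, for contrast, w.r.t. the parameters (`argnums=0`), then compares one sample against plain `torch.autograd.functional.jacobian`.

```python
import torch
import torch.nn as nn
from torch.func import vmap, jacrev, functional_call

torch.manual_seed(0)

# Hypothetical stand-in for myModel: a tiny MLP (sizes chosen arbitrarily)
batch_size, input_size, output_size = 4, 3, 2
model = nn.Sequential(nn.Linear(input_size, 5), nn.Tanh(), nn.Linear(5, output_size))

# detach() so regular autograd does not also record these ops; the transforms differentiate
params = {k: v.detach() for k, v in model.named_parameters()}
inputs = torch.randn(batch_size, input_size)

def fmodel(params, x):
    # Inside vmap, x is a single sample of shape (input_size,)
    return functional_call(model, params, (x,))

# Per-sample Jacobian d(output)/d(input): shape (batch, output_size, input_size)
jac_inputs = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)
print(jac_inputs.shape)  # torch.Size([4, 2, 3])

# Per-sample Jacobian w.r.t. the parameters (argnums=0): a dict keyed like params
jac_params = vmap(jacrev(fmodel, argnums=0), in_dims=(None, 0))(params, inputs)
print(jac_params["0.weight"].shape)  # torch.Size([4, 2, 5, 3])

# Sanity check against plain autograd for the first sample
ref = torch.autograd.functional.jacobian(lambda x: model(x), inputs[0])
print(torch.allclose(jac_inputs[0], ref, atol=1e-6))  # True
```

With `argnums=0` the result mirrors the structure of `params`, so each entry has the batch dimension, then the output shape, then that parameter's shape.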
The documentation for `torch.func` can be found here.
Also, if you're migrating from `functorch` to `torch.func`, there is a documentation page describing the changes between them here.
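One migration detail worth knowing: `torch.func` has no `make_functional_with_buffers`; instead, `functional_call` accepts a tuple of dicts, so parameters and buffers can be passed together. A minimal sketch, assuming a hypothetical model with a `BatchNorm1d` layer put in eval mode (so its running statistics are buffers that are read but not updated), and an arbitrary MSE loss for the gradient:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)

# Hypothetical model with buffers: BatchNorm1d stores running_mean, running_var
# and num_batches_tracked as buffers. eval() makes it use (not update) them.
model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4), nn.Linear(4, 2)).eval()

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def loss_fn(params, buffers, x, y):
    # Parameters and buffers are passed together as a tuple of dicts
    out = functional_call(model, (params, buffers), (x,))
    return nn.functional.mse_loss(out, y)

x = torch.randn(8, 3)
y = torch.randn(8, 2)

# The functional forward pass matches calling the module directly
out = functional_call(model, (params, buffers), (x,))
print(torch.allclose(out, model(x)))  # True

# Gradient of the loss w.r.t. the parameters only (grad differentiates argument 0 by default)
grads = grad(loss_fn)(params, buffers, x, y)
print(sorted(grads))  # ['0.bias', '0.weight', '1.bias', '1.weight', '2.bias', '2.weight']
```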