Hi all,
TL;DR - How can I compute matrix-vector products with a matrix built from the Jacobian of a neural network (here, the Fisher information matrix JᵀJ) without fully materializing the Jacobian?
I’m trying to figure out how to compute such matrix-vector products efficiently via Jacobian-vector and vector-Jacobian products, which allow the product to be computed without ever materializing the full matrix.
Here’s an example of a so-called ‘dense’ approach below.
import torch
_ = torch.manual_seed(0)
from torch import nn, Tensor
from torch.func import functional_call, vmap, jacrev, jvp, vjp
from torch.utils._pytree import tree_flatten
device=torch.device('cpu')
dtype=torch.float32
factory_kwargs = {'device':device,'dtype':dtype}
class Model(nn.Module):
    def __init__(self, H: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, H)
        self.fc2 = nn.Linear(H, H)
        self.fc3 = nn.Linear(H, 1)
        self.act_func = nn.Tanh()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_func(x)
        x = self.fc2(x)
        x = self.act_func(x)
        x = self.fc3(x)
        return x.squeeze(-1)
model = Model(H=64).to(**factory_kwargs) #our model
x = torch.randn(4096, 4, **factory_kwargs) #functional inputs
params = dict(model.named_parameters())
#functionalize the call of the network
def fcall(params, x):
    return functional_call(module=model, parameter_and_buffer_dicts=params, args=(x,))
def tree_to_vector(pytree) -> Tensor:
    v, _ = tree_flatten(pytree) #flatten all leaves of the pytree into a list
    return torch.cat([x.reshape(-1) for x in v], dim=0)
v1 = {k: torch.randn_like(p) for k, p in params.items()} #random vector (for our matrix-vector product)
vector = tree_to_vector(v1)
jac_wrt_params_pytree = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x) #per-sample Jacobian w.r.t. params, as a pytree
jac_wrt_params = vmap(tree_to_vector, in_dims=0)(jac_wrt_params_pytree) #flatten each per-sample Jacobian, shape [4096, 4545]
FIM = jac_wrt_params.T @ jac_wrt_params #dense FIM shape [4545,4545]
matrix_vec = FIM @ vector
print(matrix_vec.shape) #shape [4545,]
However, it seems that you can actually compute the matrix-vector product without ever fully materializing the Jacobian by composing torch.func.jvp and torch.func.vjp calls.
How exactly could this be done with the example shown above?
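Here is roughly the composition I have in mind, reusing fcall, params, x and v1 from the snippet above (an untested sketch; as I understand it, jvp gives J @ v in forward mode and vjp applies J.T to a cotangent in reverse mode):

#matrix-free FIM-vector product: FIM @ v = J.T @ (J @ v)
def fim_vector_product(params, x, v):
    f = lambda p: fcall(p, x)
    _, Jv = jvp(f, (params,), (v,)) #forward mode: J @ v, shape [4096]
    _, vjp_fn = vjp(f, params) #reverse mode: returns a function that applies J.T
    (JtJv,) = vjp_fn(Jv) #J.T @ (J @ v), a pytree with the same structure as params
    return JtJv

matrix_vec_free = tree_to_vector(fim_vector_product(params, x, v1)) #shape [4545,]
print((matrix_vec - matrix_vec_free).abs().max()) #sanity check against the dense result

Is this the right way to compose the two transforms for the example above?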
Of course, I could exploit the associativity of the matrix @ vector product via an opt_einsum call like torch.einsum('bi,bj,j->i', jac_wrt_params, jac_wrt_params, vector), however that still requires me to fully materialize the Jacobian of my network (which is impractical for larger networks).
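For concreteness, the associativity route I mean is equivalent to the two-step contraction below, using jac_wrt_params and vector from the dense example above (the [4545, 4545] FIM is never formed, but the [4096, 4545] Jacobian still is):

Jv = torch.einsum('bj,j->b', jac_wrt_params, vector) #J @ v, contracting over j first, shape [4096]
matrix_vec_assoc = torch.einsum('bi,b->i', jac_wrt_params, Jv) #J.T @ (J @ v), shape [4545,]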
Any help or ideas on how to solve this problem are welcome!