How to perform matrix-vector products efficiently with jvp and vjp calls?

Hi all,

TL;DR - How can I compute matrix-vector products (here, FIM-vector products) for a neural network without fully materializing the network's Jacobian?

I’m trying to figure out how to compute these matrix-vector products efficiently by composing Jacobian-vector and vector-Jacobian products, which lets you form the product without ever materializing the full matrix.
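
Concretely, for a matrix of the form F = J^T J (like the FIM in the example below, where J is the per-sample Jacobian of the network output w.r.t. the parameters), the product should factor as

F v = J^T (J v),

so the inner J v is one forward-mode (jvp) pass and the outer J^T u is one reverse-mode (vjp) pass, at least as far as I understand it.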

Here’s an example of the so-called ‘dense’ approach:

import torch
_ = torch.manual_seed(0)
from torch import nn, Tensor
from torch.func import functional_call, vmap, jacrev, jvp, vjp
from torch.utils._pytree import tree_flatten

device=torch.device('cpu')
dtype=torch.float32

factory_kwargs = {'device':device,'dtype':dtype}

class Model(nn.Module):
    def __init__(self, H: int) -> None:
        super(Model, self).__init__()

        self.fc1 = nn.Linear(4,H)
        self.fc2 = nn.Linear(H,H)
        self.fc3 = nn.Linear(H,1)
        self.act_func = nn.Tanh()

    def forward(self, x):
        x=self.fc1(x)
        x=self.act_func(x)
        x=self.fc2(x)
        x=self.act_func(x)
        x=self.fc3(x)
        return x.squeeze(-1)

model = Model(H=64).to(**factory_kwargs) #our model

x = torch.randn(4096, 4, **factory_kwargs) #functional inputs 
params = dict(model.named_parameters())

#functionalize the call of the network
def fcall(params, x):
    return functional_call(module=model, parameter_and_buffer_dicts=params, args=(x,))

def tree_to_vector(pytree) -> Tensor:
    v, _ = tree_flatten(pytree) #flatten all leaves of pytree to list 
    return torch.cat([x.reshape(-1) for x in v] , dim=0)

v1 = {k: torch.randn_like(p) for k, p in params.items()} #random vector (for our matrix-vector product)

vector = tree_to_vector(v1)

jac_wrt_params_pytree = vmap(jacrev(fcall, argnums=0), in_dims=(None,0))(params, x)
jac_wrt_params = vmap(tree_to_vector, in_dims=(0))(jac_wrt_params_pytree) #flatten to vector

FIM = jac_wrt_params.T @ jac_wrt_params #dense FIM shape [4545,4545]

matrix_vec = FIM @ vector
print(matrix_vec.shape) #shape [4545,]

However, it seems that you can actually compute the matrix-vector product without ever fully materializing the Jacobian, by composing torch.func.jvp and torch.func.vjp calls.
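
My rough guess is that it would look something like this (reusing fcall, params, x, v1 and tree_to_vector from the snippet above; I’m not sure this is the right way to chain the calls):

#J @ v via forward mode (jvp), differentiating only w.r.t. the parameters
_, Jv = jvp(lambda p: fcall(p, x), (params,), (v1,)) #shape [4096]

#J^T @ u via reverse mode (vjp) on the same function
_, vjp_fn = vjp(lambda p: fcall(p, x), params)
(JTJv,) = vjp_fn(Jv) #pytree with the same structure as params

print(tree_to_vector(JTJv).shape) #torch.Size([4545])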

How exactly could this be done with the example shown above?

Of course, I could exploit the associativity of the matrix-vector product via an opt_einsum-style call like torch.einsum('bi,bj,j->i', jac_wrt_params, jac_wrt_params, vector), however that still requires me to fully materialize the Jacobian of my network (which is impractical for larger networks).
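
For reference, the associativity trick I mean is roughly this (still using the materialized jac_wrt_params and the flat vector from the dense snippet): it avoids forming the [4545, 4545] FIM, but the full [4096, 4545] Jacobian is still needed.

#contract J @ vector first (shape [4096]), then J^T @ (J @ vector) (shape [4545]);
#the [4545, 4545] FIM is never formed, but the full Jacobian still is
Jv = torch.einsum('bj,j->b', jac_wrt_params, vector)
matrix_vec_assoc = torch.einsum('bi,b->i', jac_wrt_params, Jv)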

Any help or ideas on how to solve this problem are welcome!

Here’s the solution: by composing the matrix-vector product out of jvp and vjp calls you can significantly lower the walltime. The script below demonstrates and benchmarks this.

When profiling the two functions matvec_dense and matvec_onthefly you get the following:

print("Matrix-vector product: Dense")
print(prof_dense.key_averages().table(sort_by=sort_by, row_limit=10))
#Self CPU time total: 48.244ms
#Self CUDA time total: 49.134ms
print("Matrix-vector product: On-The-Fly")
print(prof_otf.key_averages().table(sort_by=sort_by, row_limit=10))
#Self CPU time total: 1.237ms
#Self CUDA time total: 1.405ms

which shows a 39x and 35x speed-up on the CPU and GPU, respectively, with near-identical outputs:

out_dense = matvec_dense(vector=v1)
out_otf = matvec_onthefly(vector=v1)

print(out_dense)
print(out_otf)
"""
#returns
tensor([   95.3212,   -42.1294,  -179.3537,  ..., -1032.9551,  1085.7876,
         6401.7373], device='cuda:0', grad_fn=<MvBackward0>)
tensor([   95.3214,   -42.1296,  -179.3538,  ..., -1032.9546,  1085.7891,
         6401.7402], device='cuda:0')
"""

Here’s the full script:

import torch
_ = torch.manual_seed(0)
from torch import nn, Tensor
from torch.func import functional_call, vmap, jacrev, jvp, vjp
from torch.utils._pytree import tree_flatten

device=torch.device('cuda')
dtype=torch.float32

factory_kwargs = {'device':device,'dtype':dtype}

class Model(nn.Module):
    def __init__(self, H: int) -> None:
        super(Model, self).__init__()

        self.fc1 = nn.Linear(4,H)
        self.fc2 = nn.Linear(H,H)
        self.fc3 = nn.Linear(H,1)
        self.act_func = nn.Tanh()

    def forward(self, x):
        x=self.fc1(x)
        x=self.act_func(x)
        x=self.fc2(x)
        x=self.act_func(x)
        x=self.fc3(x)
        return x.squeeze(-1)

model = Model(H=64).to(**factory_kwargs) #our model

x = torch.randn(4096, 4, **factory_kwargs) #functional inputs 
params = dict(model.named_parameters())

#functionalize the call of the network
def fcall(params, x):
    return functional_call(module=model, parameter_and_buffer_dicts=params, args=(x,))

def tree_to_vector(pytree) -> Tensor:
    v, _ = tree_flatten(pytree) #flatten all leaves of pytree to list 
    return torch.cat([x.reshape(-1) for x in v] , dim=0)

v1 = {k: torch.randn_like(p) for k, p in params.items()} #random vector (for our matrix-vector product)

def matvec_dense(vector):

    vector = tree_to_vector(vector)

    jac_wrt_params_pytree = vmap(jacrev(fcall, argnums=0), in_dims=(None,0))(params, x)
    jac_wrt_params = vmap(tree_to_vector, in_dims=(0))(jac_wrt_params_pytree) #flatten to vector

    FIM = jac_wrt_params.T @ jac_wrt_params #dense FIM shape [4545,4545]

    matrix_vec = FIM @ vector #shape [4545,]
    return matrix_vec

def matvec_onthefly(vector):

    def fun(params):
        return functional_call(module=model, parameter_and_buffer_dicts=params, args=(x,))

    #linearize returns the primal output and a linear jvp function u -> J @ u
    _, jvp_fn = torch.func.linearize(fun, params)

    #vjp through the (linear) jvp function gives u -> J^T @ u;
    #its primal output is jvp_fn(vector) = J @ vector, so it is reused directly
    jac_vector, vjp_fn = vjp(jvp_fn, vector)

    #J^T @ (J @ vector), returned as a pytree with the same structure as params
    res = vjp_fn(jac_vector)
    return tree_to_vector(res)


out_dense = matvec_dense(vector=v1)
out_otf = matvec_onthefly(vector=v1)

print(out_dense)
print(out_otf)
"""
#returns
tensor([   95.3212,   -42.1294,  -179.3537,  ..., -1032.9551,  1085.7876,
         6401.7373], device='cuda:0', grad_fn=<MvBackward0>)
tensor([   95.3214,   -42.1296,  -179.3538,  ..., -1032.9546,  1085.7891,
         6401.7402], device='cuda:0')
"""

from torch.profiler import profile, ProfilerActivity

#profile both CPU and CUDA activity so both time totals are reported
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False, profile_memory=True, with_flops=True) as prof_dense:
    delta_dense = matvec_dense(v1)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False, profile_memory=True, with_flops=True) as prof_otf:
    delta_otf = matvec_onthefly(v1)

sort_by = "cuda_memory_usage" #"cpu_memory_usage"

print("Matrix-vector product: Dense")
print(prof_dense.key_averages().table(sort_by=sort_by, row_limit=10))
#Self CPU time total: 48.244ms
#Self CUDA time total: 49.134ms
print("Matrix-vector product: On-The-Fly")
print(prof_otf.key_averages().table(sort_by=sort_by, row_limit=10))
#Self CPU time total: 1.237ms
#Self CUDA time total: 1.405ms

print('delta_dense: ',delta_dense)
print('delta_otf: ',delta_otf)
"""
#returns
delta_dense:  tensor([   95.3212,   -42.1294,  -179.3537,  ..., -1032.9551,  1085.7876,
         6401.7373], device='cuda:0', grad_fn=<MvBackward0>)
delta_otf:  tensor([   95.3214,   -42.1296,  -179.3538,  ..., -1032.9546,  1085.7891,
         6401.7402], device='cuda:0')
"""