If you’re hitting an OOM error, you can always ‘chunk’ the vmap operation: run it over several mini-batches and then concatenate the results.
In the example you shared, a ‘chunk’ version would look something like this:
import torch
from functorch import make_functional, vmap, grad
model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)
func_model, params = make_functional(model)
def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)
# full-batch version, for reference
out_full = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data, targets)
nchunks = 8  # number of chunks
nparams = len(params)  # number of params, 2 in this case (weight & bias)
# perform vmap in chunks
out = [vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data_chunk, targets_chunk)
       for data_chunk, targets_chunk in zip(data.chunk(nchunks), targets.chunk(nchunks))]
# flatten the list of per-chunk gradient tuples into a single flat list
out = [item for sublist in out for item in sublist]
# regroup per parameter and concatenate along the batch dimension
out_chunk = [torch.cat([out[nparams * chunk + p] for chunk in range(nchunks)], dim=0)
             for p in range(nparams)]
# check that both methods agree
for i in range(nparams):
    print("Param: %i Match? %s" % (i, torch.allclose(out_full[i], out_chunk[i])))
"""
Returns
Param: 0 Match? True
Param: 1 Match? True
"""
If you still have an issue, you can open an issue on functorch’s github repo here. There is ongoing work to add ‘chunk’ support to vmap itself, as seen in issue #680, but it’s still under development I believe!
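For what it’s worth, the torch.func API (functorch’s successor in PyTorch 2.x) did eventually gain a built-in chunk_size argument on vmap that does this mini-batch splitting internally. A rough sketch of the same per-sample-gradient computation, assuming a recent PyTorch (2.1 or later) where torch.func and the chunk_size argument are available:

```python
# Sketch assuming PyTorch >= 2.1, where torch.func replaces functorch
# and vmap accepts a chunk_size keyword that splits the batch internally.
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(3, 3)
# a plain dict of tensors plays the role of make_functional's params
params = {k: v.detach() for k, v in model.named_parameters()}
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

def compute_loss(params, x, y):
    preds = functional_call(model, params, (x,))
    return torch.mean((preds - y) ** 2)

# chunk_size=8 processes the vmapped batch 8 samples at a time,
# trading a bit of speed for a smaller peak memory footprint
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0), chunk_size=8)(
    params, data, targets
)
```

If your PyTorch version doesn’t have chunk_size yet, the manual chunk-and-concatenate approach shown earlier achieves the same effect.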