I’m trying to use `torch.compile` around a vmapped ensemble, but it fails with `Batching rule not implemented for aten::t`. A minimal example:
```python
import copy
from functools import wraps

import torch
from torch import nn
from torch._dynamo import allow_in_graph
from torch.func import stack_module_state, functional_call
import functorch


def traceable(f):
    # allow_in_graph makes dynamo treat the vmapped callable as a single
    # opaque op instead of trying to trace into it.
    f = allow_in_graph(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


x = torch.randn(3)


class Net(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fc = nn.Linear(3, 5)

    def forward(self, x):
        return self.fc(x)


models = [Net() for _ in range(5)]
params, buffers = stack_module_state(models)

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')


def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))


vmapped_fmodel = functorch.vmap(fmodel, in_dims=(0, 0, None))

# Works
vmapped_fmodel(params, buffers, x)

# Doesn't work: raises "Batching rule not implemented for aten::t"
torch.compile(traceable(vmapped_fmodel))(params, buffers, x)
```
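One variant I have been meaning to try (an untested sketch, assuming that on recent PyTorch versions dynamo can trace `torch.func.vmap` directly, so the `allow_in_graph` wrapper would be unnecessary):

```python
import torch
from torch.func import vmap  # torch.func.vmap instead of the deprecated functorch.vmap

# Untested sketch: let dynamo trace the vmap transform itself rather than
# hiding it behind allow_in_graph. Reuses fmodel/params/buffers/x from above.
compiled = torch.compile(vmap(fmodel, in_dims=(0, 0, None)))
out = compiled(params, buffers, x)  # expected shape: (5, 5)
```

I don’t know whether that actually avoids the missing batching rule or just moves the failure around, though.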
I already opened a GitHub issue about the missing `aten::t` batching rule. Here I’m just wondering if there is any simple trick to get around it in the meantime (one hand-rolled idea is sketched below). Thanks!
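For this toy case specifically, the only workaround I can think of is to batch the linear layer by hand, which avoids vmap (and hence the `aten::t` batching rule) altogether. A sketch, assuming the stacked parameters keep the names `stack_module_state` produces, i.e. `params['fc.weight']` of shape (5, 5, 3) and `params['fc.bias']` of shape (5, 5):

```python
import torch

# Hand-batched equivalent of the vmapped ensemble for this specific Net.
# 'e' is the ensemble dim, 'o'/'i' the output/input features of the Linear:
# out[e] = params['fc.weight'][e] @ x + params['fc.bias'][e]
def manual_fmodel(params, x):
    return torch.einsum('eoi,i->eo', params['fc.weight'], x) + params['fc.bias']

out = torch.compile(manual_fmodel)(params, x)  # shape (5, 5)
```

That obviously doesn’t scale beyond this minimal example, which is why I’d still like the vmapped version to compile.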