`torch.compile` a vmapped ensemble

I’m trying to wrap a vmapped ensemble in `torch.compile`, but it fails with the error `Batching rule not implemented for aten::t`. Here is a minimal example:

import torch
from torch import nn
from torch._dynamo import allow_in_graph
from functools import wraps
from torch.func import stack_module_state, functional_call
import functorch
import copy

def traceable(f):
    """Pass ``f`` through ``allow_in_graph`` and return a forwarding wrapper.

    Args:
        f: The callable to register via ``torch._dynamo.allow_in_graph``.

    Returns:
        A plain function that forwards all positional and keyword arguments
        to the object returned by ``allow_in_graph(f)`` and returns its
        result. The wrapper carries ``f``'s metadata (``__name__``,
        ``__doc__``, ``__wrapped__``) so it stays identifiable in tracebacks
        and debugging output.
    """
    f = allow_in_graph(f)

    # ``wraps`` keeps the wrapper from masquerading as an anonymous
    # "wrapper" function; call behaviour is unchanged.
    @wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper

# Random 3-element input; it is handed to the vmapped call below with
# in_dims=None, i.e. the same tensor is used for every ensemble member.
x = torch.randn((3,))

class Net(nn.Module):
    """Minimal module holding a single ``nn.Linear(3, 5)`` layer as ``fc``."""

    def __init__(self, *args, **kwargs) -> None:
        # Any constructor arguments are forwarded unchanged to the base class.
        super().__init__(*args, **kwargs)
        self.fc = nn.Linear(in_features=3, out_features=5)

    def forward(self, x):
        """Apply the ``fc`` layer to ``x`` and return the result."""
        projected = self.fc(x)
        return projected

# Build the ensemble: five separately constructed Net instances, whose
# parameters and buffers are then combined by stack_module_state.
models = [Net() for _ in range(5)]
params, buffers = stack_module_state(models)

# A copy of the first member moved to the 'meta' device serves as the
# "stateless" skeleton: its parameters are meta Tensors without storage, so
# the real values have to be supplied at call time.
base_model = copy.deepcopy(models[0]).to('meta')

def fmodel(params, buffers, x):
    """Run ``base_model`` on ``x`` using the supplied parameters and buffers.

    Args:
        params: Parameters handed to ``functional_call`` in place of the
            ones stored on ``base_model``.
        buffers: Buffers handed to ``functional_call`` alongside ``params``.
        x: Input passed to the model as its single positional argument.

    Returns:
        Whatever ``functional_call`` returns for ``base_model`` on ``x``.
    """
    return functional_call(base_model, (params, buffers), (x,))


# Map over dim 0 of ``params`` and ``buffers`` while sharing ``x`` across all
# ensemble members. ``torch.func.vmap`` is used instead of the deprecated
# ``functorch.vmap`` so the script relies on a single API family
# (``torch.func``) throughout.
vmapped_fmodel = torch.func.vmap(fmodel, in_dims=(0, 0, None))

# Calling the vmapped ensemble directly (no compilation) works.
vmapped_fmodel(params, buffers, x)

# Failing case: the same call routed through torch.compile on the
# traceable()-wrapped function doesn't work (the reported error is
# "Batching rule not implemented for aten::t").
compiled_fmodel = torch.compile(traceable(vmapped_fmodel))
compiled_fmodel(params, buffers, x)

I have already opened a GitHub issue regarding `aten::t`. Here I’m just wondering whether there is a simple trick to work around this issue in the meantime. Thanks!