Vmap an nn.Module

Hi, I found the new functorch.vmap feature very useful. I am trying to use vmap to parallelize multiple decoders that share the same input and output shapes. There are some examples of using vmap with an nn.Module in the docs, but none of them show how to use vmap inside a class. Below is my initial idea for implementing vmap inside a class; can you help me with it? Thanks!

Without vmap it looks like the following; the Python loop over decoders becomes very time-consuming when num_mod is large.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.num_mod = 6

        self.encoder = nn.RNN(input_size=10, hidden_size=10)  # example sizes
        decoder = [nn.Linear(10, 1) for i in range(self.num_mod)]
        self.decoder = nn.ModuleList(decoder)

    def forward(self, x):
        x, _ = self.encoder(x)  # nn.RNN returns (output, h_n)

        x = x.repeat(self.num_mod, 1)
        out = []
        for i in range(self.num_mod):
            out.append(self.decoder[i](x[i]))

        return torch.stack(out, dim=1)

Since I learned that vmap can also be applied to an nn.Module, I tried to achieve something like the following, but I don't know how to finish it:

import torch
import torch.nn as nn
from functorch import make_functional, vmap

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.num_mod = 6

        self.encoder = nn.RNN(input_size=10, hidden_size=10)  # example sizes

        # func expects (params, input), but I am not sure how to store params here,
        # or how to get num_mod separate decoders instead of just one
        func, params = make_functional(nn.Linear(10, 1))
        self.decoder = vmap(func)

    def forward(self, x):
        x, _ = self.encoder(x)  # nn.RNN returns (output, h_n)
        x = x.repeat(self.num_mod, 1)

        return self.decoder(x)
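
For context, this is roughly the direction I think I need, based on the model-ensembling example in the functorch docs (combine_state_for_ensemble). The RNN sizes, taking the last time step of the RNN output, and registering the stacked parameters with an nn.ParameterList are all assumptions I made just so the sketch is self-contained:

import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.num_mod = 6

        self.encoder = nn.RNN(input_size=10, hidden_size=10)  # example sizes

        # stack the parameters of num_mod identical decoders behind one functional model
        decoders = [nn.Linear(10, 1) for _ in range(self.num_mod)]
        fmodel, params, buffers = combine_state_for_ensemble(decoders)
        self.decoder = fmodel
        # my assumption about how to keep the stacked parameters trainable
        # inside the module: register them as an nn.ParameterList
        self.decoder_params = nn.ParameterList(
            nn.Parameter(p.detach().clone()) for p in params
        )
        self.decoder_buffers = buffers  # empty for nn.Linear

    def forward(self, x):
        x, _ = self.encoder(x)  # nn.RNN returns (output, h_n)
        x = x[-1]               # last time step, shape (batch, 10) -- an assumption

        # in_dims=(0, 0, None): map over the stacked params/buffers and
        # broadcast the same encoder output to every decoder
        out = vmap(self.decoder, in_dims=(0, 0, None))(
            tuple(self.decoder_params), self.decoder_buffers, x
        )
        return out              # shape (num_mod, batch, 1)

Is this the right way to keep the stacked decoder parameters trainable inside the class, or is there a cleaner pattern?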