vmap MLP ensemble: zero grads after update

I’m quite sure I did something wrong, but I’m not really sure what.

Basically, there are no grads on the parameters after the backward/update step (optimizer.step()), and the loss doesn’t decrease. I test it with a trivial mapping from x ~ N(0,1) to y ~ N(0,1), where both tensors are fixed before training. (A quick check right after the training script below confirms that every parameter’s .grad stays None.)

The code I use is:

import copy

import torch
import torch.nn as nn
from torch import optim, vmap
from torch.func import stack_module_state, functional_call
from tqdm import tqdm

from architectures import mlp


class network(nn.Module):
    def __init__(self, embedding_size: int, num_modules: int):
        super().__init__()
        self.ensamble_modules = nn.ModuleList([mlp([embedding_size, embedding_size]) for _ in range(num_modules)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stateless copy on the meta device, used only as a structural template for functional_call.
        base_model = copy.deepcopy(self.ensamble_modules[0]).to(device='meta')

        # Stack all ensemble members' parameters/buffers along a new leading dimension.
        ensemble_params, ensemble_buffers = stack_module_state(self.ensamble_modules)
        ensemble_call = lambda params, buffers, x: functional_call(base_model, (params, buffers), x)

        # vmap over the module dimension: member i gets parameter slice i and input slice x[i].
        return vmap(ensemble_call)(ensemble_params, ensemble_buffers, x)


if __name__ == "__main__":
    BATCH_SIZE = 1
    EMBEDDING_SIZE = 10
    NUM_MODULES = 4
    # Fixed random input/target pair, one slice per ensemble member.
    x = torch.randn(NUM_MODULES, BATCH_SIZE, EMBEDDING_SIZE)
    y = torch.randn(NUM_MODULES, BATCH_SIZE, EMBEDDING_SIZE)
    net = network(EMBEDDING_SIZE, NUM_MODULES)

    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for _ in tqdm(range(100000), ascii=True):
        optimizer.zero_grad()
        y_pred = net(x)
        loss = nn.MSELoss()(y_pred, y)
        loss.backward()
        print(loss.item())
        optimizer.step()

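To show what I mean by “no grads”: adding something like this after the training loop (a minimal check, reusing net, x and y from the script above) prints True, i.e. none of the registered parameters ever receive a gradient:

y_pred = net(x)
loss = nn.MSELoss()(y_pred, y)
loss.backward()
# Prints True: every parameter in the ModuleList still has grad == None after backward,
# so optimizer.step() has nothing to update.
print(all(p.grad is None for p in net.parameters()))
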
architectures.py is:

import torch
import torch.nn as nn


class mlp(nn.Module):
    def __init__(self, struct: list, act_f: nn.Module = nn.ReLU, add_layer_norm: bool = True) -> None:
        super().__init__()
        self.struct = struct
        self.act_f = act_f
        self.add_layer_norm = add_layer_norm
        self.layers = self._build()

    def _build(self) -> nn.Sequential:
        layers = []
        for index, (i, j) in enumerate(zip(self.struct[:-1], self.struct[1:])):
            layers.append(nn.Linear(i, j))
            if index != len(self.struct) - 2:  # skip norm/activation after the last layer
                if self.add_layer_norm:
                    layers.append(nn.LayerNorm(j))
                layers.append(self.act_f())
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
    
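Just to make the test case concrete: with struct = [embedding_size, embedding_size] the build loop runs once with index == len(struct) - 2, so the LayerNorm/activation branch is skipped and each ensemble member is a single Linear layer. A quick illustrative check:

# mlp([10, 10]) contains exactly one Linear layer and nothing else.
print(mlp([10, 10]).layers)
# Expected output:
# Sequential(
#   (0): Linear(in_features=10, out_features=10, bias=True)
# )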

I’m an idiot, I just saw this in the documentation:

Given a list of M nn.Modules of the same class, returns two dictionaries that stack all of their parameters and buffers together, indexed by name. The stacked parameters are optimizable (i.e. they are new leaf nodes in the autograd history that are unrelated to the original parameters and can be passed directly to an optimizer).
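
So the gradients do exist, they just land on the stacked copies that get re-created inside forward on every call, while the optimizer only watches the (now unused) parameters in self.ensamble_modules. One way to fix it, following what the docs suggest, is to stack once in __init__ and pass the stacked tensors to the optimizer directly. This is only a minimal sketch against my toy setup (the ensemble_network name and the trainable_parameters helper are just my own naming):

import copy

import torch
import torch.nn as nn
from torch import vmap
from torch.func import stack_module_state, functional_call

from architectures import mlp


class ensemble_network(nn.Module):
    def __init__(self, embedding_size: int, num_modules: int):
        super().__init__()
        modules = [mlp([embedding_size, embedding_size]) for _ in range(num_modules)]
        # Meta-device template: only its structure is used by functional_call.
        # (It is registered as a submodule, but its meta parameters are never trained.)
        self.base_model = copy.deepcopy(modules[0]).to(device='meta')
        # Stack once: these stacked tensors are the real trainable leaves from now on.
        self.ensemble_params, self.ensemble_buffers = stack_module_state(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        call = lambda params, buffers, xb: functional_call(self.base_model, (params, buffers), xb)
        return vmap(call)(self.ensemble_params, self.ensemble_buffers, x)

    def trainable_parameters(self):
        # The stacked tensors are not registered nn.Parameters, so hand them
        # to the optimizer explicitly.
        return list(self.ensemble_params.values())

and then build the optimizer from those tensors instead of net.parameters():

net = ensemble_network(EMBEDDING_SIZE, NUM_MODULES)
optimizer = optim.Adam(net.trainable_parameters(), lr=0.001)

Since the optimizer now holds the same leaves that the forward pass uses, the gradients reach tensors that actually get updated, and the toy loss should go down. The downside is that these stacked tensors are not registered parameters, so state_dict() and .to(device) won’t track them; registering them in an nn.ParameterDict (with the dots in the stacked names replaced), or torch.stack-ing the original ModuleList parameters inside forward so the autograd graph stays connected to them, would be the tidier variants.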