I’m fairly sure I did something wrong, but I can’t figure out what.
Basically, there are no grads on the parameters after the update (optimizer.step()), and the loss doesn’t go down (I test it with a trivial mapping from x ~ N(0,1) to y ~ N(0,1), where both are fixed before training).
The code I use is:
import torch
import torch.nn as nn
from torch import vmap
from architectures import mlp
from torch.func import stack_module_state, functional_call
import copy
from tqdm import tqdm
from torch import optim
class network(nn.Module):
    def __init__(self, embedding_size: int, num_modules: int):
        super().__init__()
        self.ensamble_modules = nn.ModuleList([mlp([embedding_size, embedding_size]) for _ in range(num_modules)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # stateless "meta" copy, used only as the skeleton for functional_call
        base_model = copy.deepcopy(self.ensamble_modules[0]).to(device='meta')
        # stack each member's parameters/buffers along a new leading dimension
        ensemble_params, ensemble_buffers = stack_module_state(self.ensamble_modules)
        ensemble_call = lambda ensemble_params, ensemble_buffers, x: functional_call(base_model, (ensemble_params, ensemble_buffers), x)
        # vmap over the stacked parameters and the leading (module) dimension of x
        return vmap(ensemble_call)(ensemble_params, ensemble_buffers, x)
if __name__ == "__main__":
    BATCH_SIZE = 1
    EMBEDDING_SIZE = 10
    NUM_MODULES = 4

    # fixed inputs/targets: one (batch, embedding) slice per ensemble member
    x = torch.randn(NUM_MODULES, BATCH_SIZE, EMBEDDING_SIZE)
    y = torch.randn(NUM_MODULES, BATCH_SIZE, EMBEDDING_SIZE)

    net = network(EMBEDDING_SIZE, NUM_MODULES)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for _ in tqdm(range(100000), ascii=True):
        optimizer.zero_grad()
        y_pred = net(x)
        loss = nn.MSELoss()(y_pred, y)
        loss.backward()
        print(loss.item())
        optimizer.step()
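For reference, this is roughly how I check that no gradients ever show up on the module parameters. It’s just a minimal diagnostic (reusing x, y and the constants from the script above), not part of the training loop itself:

# minimal check: after a single backward pass, do the original parameters have grads?
net = network(EMBEDDING_SIZE, NUM_MODULES)
loss = nn.MSELoss()(net(x), y)
loss.backward()
for name, p in net.named_parameters():
    print(name, p.grad)  # every parameter prints None for me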