Backpropagation with model ensembling

I need to train several neural networks that share the same architecture and receive the same input. Training them one by one takes quite a long time, so model ensembling (via torch.func) looked like a good option. However, when I try it, the models do not optimize. Here is a simple example:

import copy

import torch as th
import torch.nn as nn
from torch.func import stack_module_state, functional_call

# Toggle between the vectorized (vmap) path and the plain per-model loop.
vectorized = False

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2,1)
    def forward(self, x):
        return th.sigmoid(self.fc(x))


# Ensemble of identical networks (only one member here, to keep the example small).
models = [Net().to("cuda") for _ in range(1)]
models = nn.ModuleList(models)

optimizer = th.optim.Adam(models.parameters(), lr=0.05)

if vectorized:

    def fmodel(params, buffers, x):
        # base_model is the stateless "meta" skeleton created inside the loop below.
        return functional_call(base_model, (params, buffers), x)

    for epoch in range(100):
        data = th.rand(1, 2) * 2 - 1
        data = data.to("cuda")

        # Stack all models' parameters/buffers along a new leading dimension.
        params, buffers = stack_module_state(models)

        # Skeleton copy on the "meta" device, used only for its structure.
        base_model = copy.deepcopy(models[0])
        base_model = base_model.to("meta")

        # Vectorized forward pass over the stacked parameters.
        loss = th.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(loss.item())

else:

    # Reference loop: one forward/backward/step per model.
    for epoch in range(100):
        data = th.rand(1, 2) * 2 - 1
        data = data.to("cuda")
        for model in models:
            loss = model(data)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(loss.item())

When I set vectorized = False, the loss decreases steadily, as expected:

0.468487024307251
0.5468327403068542
0.4666518270969391
0.30017247796058655
0.42256370186805725
0.41682013869285583
0.3572870194911957
0.37354421615600586
0.39447021484375
0.3872259855270386
0.3910242021083832
0.3117615282535553
0.22395209968090057
0.28443443775177
0.26190704107284546
0.2749706506729126
0.28424859046936035
0.2088702917098999
0.29598936438560486
0.28152528405189514
0.1965540051460266
0.2668392062187195
0.2258613258600235
0.20860962569713593
0.17170515656471252
0.18912245333194733
0.18060843646526337
0.18403497338294983
0.16174425184726715
0.15166039764881134
0.15438508987426758
0.14568376541137695
0.1424552947282791
0.1341356784105301
0.13102765381336212
0.13481976091861725
0.12841664254665375
0.12801849842071533
0.11013973504304886
0.09837993234395981
0.09889375418424606
0.09404817968606949
0.11379320919513702
0.10520896315574646
0.0864768698811531
0.10952682793140411
0.08476637303829193
0.09675952792167664
0.08392863720655441
0.070079006254673
0.1028694212436676
0.0883655995130539
0.06897855550050735
0.0695270448923111
0.07020990550518036
0.08362960815429688
0.08411870151758194
0.07858551293611526
0.07350816577672958
0.07223667204380035
0.060547616332769394
0.056986112147569656
0.06019512191414833
0.05476418137550354
0.05658810958266258
0.05796775221824646
0.049511805176734924
0.04900973290205002
0.05014628544449806
0.059480972588062286
0.05562940984964371
0.04808162525296211
0.04888279363512993
0.045263487845659256
0.04549892246723175
0.0463012270629406
0.04478741064667702
0.04320802912116051
0.0414426214993
0.04260075464844704
0.04074430093169212
0.04047136381268501
0.038486938923597336
0.03848464414477348
0.03592165187001228
0.04080507159233093
0.03803026303648949
0.04044129326939583
0.04093775153160095
0.038270220160484314
0.03303438797593117
0.03105958364903927
0.03366934508085251
0.03498108685016632
0.03002573549747467
0.029859410598874092
0.030309630557894707
0.03262103721499443
0.03157965466380119
0.030938366428017616

When I set vectorized = True, the loss oscillates with no downward trend:

0.39742761850357056
0.5150707364082336
0.33502712845802307
0.40068161487579346
0.3729161322116852
0.26942333579063416
0.3000563383102417
0.26778095960617065
0.2689218819141388
0.30592694878578186
0.3212797939777374
0.48540589213371277
0.3099740445613861
0.36055657267570496
0.37235185503959656
0.5096343159675598
0.44988110661506653
0.41553425788879395
0.4014529883861542
0.2469296157360077
0.25996655225753784
0.49726536870002747
0.3670669496059418
0.45995020866394043
0.4013332426548004
0.333894819021225
0.44986388087272644
0.3017444610595703
0.43692031502723694
0.4194122850894928
0.4580802321434021
0.2955131530761719
0.3925922214984894
0.42527011036872864
0.5408198237419128
0.4509319067001343
0.40136289596557617
0.3807404637336731
0.48967868089675903
0.27931997179985046
0.26397600769996643
0.5145940780639648
0.23822587728500366
0.29648733139038086
0.49785134196281433
0.3122793436050415
0.40756678581237793
0.5716471076011658
0.550441563129425
0.3787839710712433
0.4081994593143463
0.4388037621974945
0.5091787576675415
0.37083131074905396
0.38644713163375854
0.524848222732544
0.41779041290283203
0.3404638469219208
0.3281834125518799
0.38314545154571533
0.546842634677887
0.5328224897384644
0.3231203258037567
0.36121514439582825
0.44785255193710327
0.46227821707725525
0.3008664548397064
0.2983981966972351
0.39419126510620117
0.30963829159736633
0.41809508204460144
0.2607799470424652
0.3306540548801422
0.3936608135700226
0.2748624086380005
0.404011994600296
0.4830869734287262
0.46838924288749695
0.43403947353363037
0.44998878240585327
0.4278414249420166
0.43236881494522095
0.31141236424446106
0.5225061178207397
0.5772214531898499
0.4775005877017975
0.28227731585502625
0.3465968072414398
0.3775387108325958
0.5926800966262817
0.42798036336898804
0.4840562045574188
0.37992340326309204
0.37373223900794983
0.3196784555912018
0.31561222672462463
0.33038681745529175
0.5026881098747253
0.4532962441444397
0.3159388601779938

I do not understand why this happens. Do I need to compute the gradients and perform the backpropagation step differently in the vectorized case?
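
In case it helps narrow things down, the variant I have been considering (but have not verified against the torch.func documentation) is to stack the parameters once before the loop, build the optimizer over the stacked tensors returned by stack_module_state instead of over models.parameters(), and reduce the vmapped output to a scalar before calling backward(). This is only a sketch, reusing the Net and models defined above; the name stacked_optimizer is just for illustration:

# Stack once, before training, and treat the stacked tensors as the
# leaf parameters to optimize (my assumption, not confirmed).
params, buffers = stack_module_state(models)

# Stateless skeleton on the "meta" device, used only for its structure.
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(p, b, x):
    return functional_call(base_model, (p, b), x)

stacked_optimizer = th.optim.Adam(params.values(), lr=0.05)

for epoch in range(100):
    data = th.rand(1, 2, device="cuda") * 2 - 1

    out = th.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)
    loss = out.sum()  # sum the independent per-model outputs so backward() gets a scalar

    stacked_optimizer.zero_grad()
    loss.backward()
    stacked_optimizer.step()

    print(loss.item())

If I understand stack_module_state correctly, with this variant only the stacked copies would be updated and the original models list would stay untouched, which is part of what confuses me about how training is supposed to work here.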