I need to train several neural networks with the same structure and with the same input. Training one by one takes quite a long time and I found that using model ensembling would be a good option here. However, when I try it, the models are not optimizing. I provide this simple example:
import torch as th
import torch.nn as nn
from torch.func import stack_module_state, functional_call
import sys
import copy
vectorized = False
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(2,1)
def forward(self, x):
return th.sigmoid(self.fc(x))
models = [Net().to("cuda") for _ in range(1)]
models = nn.ModuleList(models)
optimizer = th.optim.Adam(models.parameters(), lr=0.05)
if vectorized:
def fmodel(params, buffers, x):
return functional_call(base_model, (params, buffers), x)
for epoch in range(100):
data = th.rand(1,2) * 2 - 1
data = data.to("cuda")
params, buffers = stack_module_state(models)
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')
loss = th.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
else:
for epoch in range(100):
data = th.rand(1,2) * 2 - 1
data = data.to("cuda")
for model in models:
loss = model(data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
When I set vectorized=False
, the loss behaves as follows:
0.468487024307251
0.5468327403068542
0.4666518270969391
0.30017247796058655
0.42256370186805725
0.41682013869285583
0.3572870194911957
0.37354421615600586
0.39447021484375
0.3872259855270386
0.3910242021083832
0.3117615282535553
0.22395209968090057
0.28443443775177
0.26190704107284546
0.2749706506729126
0.28424859046936035
0.2088702917098999
0.29598936438560486
0.28152528405189514
0.1965540051460266
0.2668392062187195
0.2258613258600235
0.20860962569713593
0.17170515656471252
0.18912245333194733
0.18060843646526337
0.18403497338294983
0.16174425184726715
0.15166039764881134
0.15438508987426758
0.14568376541137695
0.1424552947282791
0.1341356784105301
0.13102765381336212
0.13481976091861725
0.12841664254665375
0.12801849842071533
0.11013973504304886
0.09837993234395981
0.09889375418424606
0.09404817968606949
0.11379320919513702
0.10520896315574646
0.0864768698811531
0.10952682793140411
0.08476637303829193
0.09675952792167664
0.08392863720655441
0.070079006254673
0.1028694212436676
0.0883655995130539
0.06897855550050735
0.0695270448923111
0.07020990550518036
0.08362960815429688
0.08411870151758194
0.07858551293611526
0.07350816577672958
0.07223667204380035
0.060547616332769394
0.056986112147569656
0.06019512191414833
0.05476418137550354
0.05658810958266258
0.05796775221824646
0.049511805176734924
0.04900973290205002
0.05014628544449806
0.059480972588062286
0.05562940984964371
0.04808162525296211
0.04888279363512993
0.045263487845659256
0.04549892246723175
0.0463012270629406
0.04478741064667702
0.04320802912116051
0.0414426214993
0.04260075464844704
0.04074430093169212
0.04047136381268501
0.038486938923597336
0.03848464414477348
0.03592165187001228
0.04080507159233093
0.03803026303648949
0.04044129326939583
0.04093775153160095
0.038270220160484314
0.03303438797593117
0.03105958364903927
0.03366934508085251
0.03498108685016632
0.03002573549747467
0.029859410598874092
0.030309630557894707
0.03262103721499443
0.03157965466380119
0.030938366428017616
When I set vectorized=True
, the loss seems to oscillate:
0.39742761850357056
0.5150707364082336
0.33502712845802307
0.40068161487579346
0.3729161322116852
0.26942333579063416
0.3000563383102417
0.26778095960617065
0.2689218819141388
0.30592694878578186
0.3212797939777374
0.48540589213371277
0.3099740445613861
0.36055657267570496
0.37235185503959656
0.5096343159675598
0.44988110661506653
0.41553425788879395
0.4014529883861542
0.2469296157360077
0.25996655225753784
0.49726536870002747
0.3670669496059418
0.45995020866394043
0.4013332426548004
0.333894819021225
0.44986388087272644
0.3017444610595703
0.43692031502723694
0.4194122850894928
0.4580802321434021
0.2955131530761719
0.3925922214984894
0.42527011036872864
0.5408198237419128
0.4509319067001343
0.40136289596557617
0.3807404637336731
0.48967868089675903
0.27931997179985046
0.26397600769996643
0.5145940780639648
0.23822587728500366
0.29648733139038086
0.49785134196281433
0.3122793436050415
0.40756678581237793
0.5716471076011658
0.550441563129425
0.3787839710712433
0.4081994593143463
0.4388037621974945
0.5091787576675415
0.37083131074905396
0.38644713163375854
0.524848222732544
0.41779041290283203
0.3404638469219208
0.3281834125518799
0.38314545154571533
0.546842634677887
0.5328224897384644
0.3231203258037567
0.36121514439582825
0.44785255193710327
0.46227821707725525
0.3008664548397064
0.2983981966972351
0.39419126510620117
0.30963829159736633
0.41809508204460144
0.2607799470424652
0.3306540548801422
0.3936608135700226
0.2748624086380005
0.404011994600296
0.4830869734287262
0.46838924288749695
0.43403947353363037
0.44998878240585327
0.4278414249420166
0.43236881494522095
0.31141236424446106
0.5225061178207397
0.5772214531898499
0.4775005877017975
0.28227731585502625
0.3465968072414398
0.3775387108325958
0.5926800966262817
0.42798036336898804
0.4840562045574188
0.37992340326309204
0.37373223900794983
0.3196784555912018
0.31561222672462463
0.33038681745529175
0.5026881098747253
0.4532962441444397
0.3159388601779938
I do not understand why this happens. Could it be that I need to compute the gradients and perform the backpropagation step differently?