Using vmap to train an ensemble of models at the same time

My use case requires training a list of models at the same time. Here is a heavily simplified version of my loss function:

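import torch.nn.functional as F

def loss_function(inputs, targets, models):
    loss_std = 0.0  # accumulate the cross-entropy of every model
    for model in models:
        outputs = model(inputs)
        loss_std += F.cross_entropy(outputs, targets)
    loss_std = loss_std / len(models)  # average over the ensemble
    return loss_std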

I attempted to vectorize it using something like this:

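import torch.nn.functional as F
from functorch import combine_state_for_ensemble, vmap

def loss_function(inputs, targets, models):
    # Stack the parameters and buffers of all models along a new leading dimension.
    fmodel, params, buffers = combine_state_for_ensemble(models)
    # Map over the model dimension of params/buffers; every model sees the same inputs.
    outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
    # Note: this averages the logits over models before the loss,
    # whereas the loop version above averages the per-model losses.
    loss_std = F.cross_entropy(outputs.mean(dim=0), targets)
    return loss_std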

I'm training resnet18 models, which contain batch norm layers.

When I tried to run the new vectorized loss function, I got the following error:

Traceback (most recent call last):
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 362, in <module>
    run()
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 359, in run
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, join=True)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 349, in main_worker
    train(nets, ema_nets, trainloader, optimizer, lr_scheduler, scaler, attack)
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 137, in train
    loss, _, _ = trades_loss(
  File "/home/nitish/projects/Fast-Adv-Ensemble/utils/loss.py", line 75, in trades_loss
    nat_loss, nat_outputs = adp_loss(
  File "/home/nitish/projects/Fast-Adv-Ensemble/utils/loss.py", line 50, in adp_loss
    outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 362, in wrapped
    return _flat_vmap(
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 489, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/make_functional.py", line 282, in forward
    return self.stateless_model(*args, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/projects/Fast-Adv-Ensemble/models/resnet.py", line 235, in forward
    return self._forward_impl(x)
  File "/home/nitish/projects/Fast-Adv-Ensemble/models/resnet.py", line 219, in _forward_impl
    x = self.bn1(x)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 740, in forward
    return F.batch_norm(
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format

I'm guessing NYI means "Not Yet Implemented". Is there a way to make this work in my case? I am currently using PyTorch 1.13.1. Is it implemented in 2.0.0 or in the nightly builds?

Please let me know if there is a fix for this, or an alternative approach I could try to vectorize or speed up the function.

I don't know if this is already implemented, but it's certainly a good idea to try it with the latest nightly release. If you want to keep your older PyTorch installation, you could create a new virtual environment and install the nightly binaries there.

I have tested it with the latest version, but the error persists. I am not familiar with what "contiguous" means in this context. My understanding was that the classic format is contiguous, storing the data of each channel sequentially (channels-first), and that channels-last is the exact opposite. Can you help me understand what the error means here?
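
A small standalone illustration (not taken from the original training code) of the two layouts: `torch.contiguous_format` is the default row-major NCHW layout, while `torch.channels_last` keeps the same logical NCHW shape but stores it with the channel dimension varying fastest in memory. The strides show the difference. As the error message itself says, the failure happens when code running under vmap (here inside `torch.batch_norm`) asks whether a tensor is contiguous in a memory format other than `torch.contiguous_format`, such as the channels-last check below, which batched tensors did not support at that point.

```python
import torch

# A 4-D tensor in the default NCHW layout (batch, channels, height, width).
x = torch.randn(2, 3, 4, 5)
print(x.stride())         # (60, 20, 5, 1): the width dimension varies fastest
print(x.is_contiguous())  # True

# Same logical shape, but stored channels-last (NHWC) in memory.
y = x.contiguous(memory_format=torch.channels_last)
print(y.shape)            # torch.Size([2, 3, 4, 5]): logical shape unchanged
print(y.stride())         # (60, 1, 15, 3): the channel dimension varies fastest
print(y.is_contiguous())  # False: not contiguous in the default format
# This is the kind of query the error refers to: a memory_format other than
# torch.contiguous_format.
print(y.is_contiguous(memory_format=torch.channels_last))  # True
```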

Things I am also using along with vmap:

  1. Mixed precision training
  2. Channels-last memory format (I have tried switching this off, but I still run into the same error)
  3. DistributedDataParallel

Hello, are there any further developments on this? I have a very similar problem: I want to train a dozen smaller networks on the same data, in parallel. JAX offers a vmap function for this and a nice tutorial on how to do it on a GPU, but what about PyTorch? I've searched the internet, and there seem to be no clear solutions for doing this in PyTorch. Any help would be appreciated!

Training model ensembles via vmap should now be supported in stable releases, as this tutorial indicates.
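
For reference, a minimal forward-only sketch of that pattern with the stable `torch.func` API (PyTorch 2.0 and later): `stack_module_state`, `functional_call` and `vmap`. It assumes a hypothetical small member model, `SimpleMLP`, without batch norm, so it does not exercise the batch-norm path from the original error.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, stack_module_state, vmap


class SimpleMLP(nn.Module):
    """Hypothetical ensemble member used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x.flatten(1))))


num_models = 4
models = [SimpleMLP() for _ in range(num_models)]

# Stack each parameter/buffer of all members along a new leading dimension.
params, buffers = stack_module_state(models)

# A copy of one member on the "meta" device serves as a stateless template.
base_model = copy.deepcopy(models[0]).to("meta")


def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))


x = torch.randn(64, 1, 28, 28)  # the same minibatch goes to every member
outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(outputs.shape)  # torch.Size([4, 64, 10])
```

Putting the template on the `meta` device avoids allocating another real copy of the weights; the actual values always come from the stacked `params` and `buffers`.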

Thanks, it works very well!

Thanks once again for that. Do you have any idea how to backpropagate and update the parameters in a similar manner? When training my ensemble, the forward call is indeed much faster with vmap than with a classic loop over the models, but loss.backward() and especially optimizer.step() seem just as slow as before (which is to be expected). Is there any way to accelerate these, maybe using functional optimizers? Thank you!

I haven't tried it myself, as I have just started working on the vmap ensemble idea. But maybe you could combine the losses of the individual models into one big loss and do a single backward pass for all models together?
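
A minimal sketch of that idea, assuming PyTorch 2.x with `torch.func` and a hypothetical tiny `nn.Sequential` member model: compute each member's loss from the vmapped logits, sum the losses, and call `backward()` once. Summing rather than averaging keeps each member's gradient equal to the gradient of its own loss, because member i's parameters only appear in its own term; the last lines check this against an ordinary backward pass on one member.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
num_models, batch_size, num_classes = 3, 16, 10

# Hypothetical ensemble members (no batch norm, so there are no buffers).
models = [
    nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, num_classes))
    for _ in range(num_models)
]
params, buffers = stack_module_state(models)  # leading dim = ensemble member
base_model = copy.deepcopy(models[0]).to("meta")  # structure only, no real weights


def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))


inputs = torch.randn(batch_size, 20)
targets = torch.randint(0, num_classes, (batch_size,))

# logits: (num_models, batch_size, num_classes), all members in one call.
logits = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)

# Per-member mean cross-entropy, shape (num_models,). Flattening is model-major,
# so repeating the targets num_models times lines them up with the logits.
per_model_loss = F.cross_entropy(
    logits.flatten(0, 1), targets.repeat(num_models), reduction="none"
).view(num_models, batch_size).mean(dim=1)

# One backward pass for the whole ensemble.
per_model_loss.sum().backward()

# Sanity check: member 0's gradient matches an ordinary backward pass on models[0].
F.cross_entropy(models[0](inputs), targets).backward()
print(torch.allclose(params["0.weight"].grad[0], models[0][0].weight.grad, atol=1e-5))
```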

Hi Yu, I have done this, but I wonder if it could be accelerated somehow. I observe that the optimizer.step() call takes quite some time (especially for wider ensemble member architectures). I know that functional optimizers exist (e.g. in torchopt), but I have no idea whether they would be useful here or whether they could vectorize the optimizer.step() call. By the way, if you have trouble using vmap to train an ensemble, feel free to ask; I have implemented a working version of this.
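
One option that does not need a functional optimizer is to give a regular optimizer the stacked tensors from `stack_module_state` instead of every member's individual parameters. The optimizer then iterates over only as many tensors as a single member has, each with a leading ensemble dimension, so every update touches a few large tensors rather than many small ones. Whether this is noticeably faster depends on the model and hardware and is not benchmarked here. Because Adam updates each element independently, this should behave like a separate Adam per member with the same hyperparameters; optimizers that compute statistics over a whole tensor would mix the members. Sketch assuming PyTorch 2.x and a hypothetical MLP member:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
num_models, num_classes = 8, 10

# Hypothetical member architecture with 4 parameter tensors (2 weights, 2 biases).
models = [
    nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, num_classes))
    for _ in range(num_models)
]
params, buffers = stack_module_state(models)
base_model = copy.deepcopy(models[0]).to("meta")


def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))


# The optimizer receives 4 stacked tensors instead of 4 * num_models = 32 small ones.
optimizer = torch.optim.Adam(params.values(), lr=1e-3)
print(len(optimizer.param_groups[0]["params"]))  # 4

inputs = torch.randn(32, 20)
targets = torch.randint(0, num_classes, (32,))

for step in range(5):
    optimizer.zero_grad()
    logits = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
    per_model_loss = F.cross_entropy(
        logits.flatten(0, 1), targets.repeat(num_models), reduction="none"
    ).view(num_models, -1).mean(dim=1)
    per_model_loss.sum().backward()
    optimizer.step()  # one step updates all num_models members
```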

Hi, would you be able to share your code for this? I am a bit unsure how to hand the weights to the optimizer when using the ensembles.
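
A minimal sketch (not the code mentioned above) of where the weights live, assuming PyTorch 2.x `torch.func` and a hypothetical small member model: the optimizer is built over the stacked tensors returned by `stack_module_state`. Those tensors are copies, so the original `nn.Module` instances are not updated during training. The hypothetical helper `copy_stacked_into_models` writes each member's slice back, e.g. for evaluation or saving.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
num_models, num_classes = 3, 10
# Hypothetical member architecture.
models = [
    nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, num_classes))
    for _ in range(num_models)
]

# Give the optimizer the STACKED tensors. stack_module_state copies the weights,
# so the parameters inside `models` are not touched by this optimizer.
params, buffers = stack_module_state(models)
optimizer = torch.optim.SGD(params.values(), lr=0.1)
base_model = copy.deepcopy(models[0]).to("meta")


def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))


inputs = torch.randn(16, 20)
targets = torch.randint(0, num_classes, (16,))

# One training step for the whole ensemble.
logits = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
loss = F.cross_entropy(
    logits.flatten(0, 1), targets.repeat(num_models), reduction="none"
).view(num_models, -1).mean(dim=1).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()


def copy_stacked_into_models(models, params, buffers):
    """Hypothetical helper: load member i's slice of each stacked tensor into models[i]."""
    for i, model in enumerate(models):
        model.load_state_dict({name: t[i] for name, t in {**params, **buffers}.items()})


copy_stacked_into_models(models, params, buffers)
print(torch.equal(models[1][0].weight, params["0.weight"][1]))  # True
```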