Using vmap to train an ensemble of models at the same time

My use case requires training a list of models at the same time. This is a heavily simplified version of my loss function:

import torch.nn.functional as F

def loss_function(inputs, targets, models):
    # Accumulate the cross-entropy loss of each model, then average over the ensemble
    loss_std = 0.0
    for model in models:
        outputs = model(inputs)
        loss_std += F.cross_entropy(outputs, targets)
    loss_std = loss_std / len(models)
    return loss_std

I attempted to vectorize the models using functorch, with something like this:

from functorch import combine_state_for_ensemble, vmap
import torch.nn.functional as F

def loss_function(inputs, targets, models):
    # Stack the parameters/buffers of all models into a single functional model
    fmodel, params, buffers = combine_state_for_ensemble(models)
    # Map the functional model over the stacked params/buffers; share the same inputs
    outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
    loss_std = F.cross_entropy(outputs.mean(dim=0), targets)
    return loss_std
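
For reference, outputs here stacks the per-model predictions along a new leading dimension (roughly [num_models, batch_size, num_classes]), so this version averages the logits before a single cross-entropy call. A minimal sketch of keeping the per-model loss averaging from the original loop instead:

# outputs: [num_models, batch_size, num_classes] from the vmapped forward
# average the per-model losses, as in the original loop
loss_std = sum(F.cross_entropy(out, targets) for out in outputs) / len(outputs)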

I'm training resnet18 models, which contain batch norm layers.

When I tried to run the new vectorized loss function, I got the following error:

Traceback (most recent call last):
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 362, in <module>
    run()
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 359, in run
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, join=True)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 349, in main_worker
    train(nets, ema_nets, trainloader, optimizer, lr_scheduler, scaler, attack)
  File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 137, in train
    loss, _, _ = trades_loss(
  File "/home/nitish/projects/Fast-Adv-Ensemble/utils/loss.py", line 75, in trades_loss
    nat_loss, nat_outputs = adp_loss(
  File "/home/nitish/projects/Fast-Adv-Ensemble/utils/loss.py", line 50, in adp_loss
    outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 362, in wrapped
    return _flat_vmap(
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 489, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/make_functional.py", line 282, in forward
    return self.stateless_model(*args, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/projects/Fast-Adv-Ensemble/models/resnet.py", line 235, in forward
    return self._forward_impl(x)
  File "/home/nitish/projects/Fast-Adv-Ensemble/models/resnet.py", line 219, in _forward_impl
    x = self.bn1(x)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 740, in forward
    return F.batch_norm(
  File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format

I'm guessing NYI means "Not Yet Implemented". Is there a way to make this work for my case? I am currently using PyTorch 1.13.1. Is it implemented in 2.0.0 or in the nightly builds?

Please let me know if there is a fix for this, or an alternative approach I could try to vectorize/speed up the function.

I don't know if this is already implemented, but it's certainly a good idea to try it out with the latest nightly release. In case you want to keep your older PyTorch installation, you could create a new virtual environment and install the nightly binaries there.

I have tested it with the latest version, but the error persists. I am not familiar with what "contiguous" means in this context. My understanding was that the classical (channels-first) format is contiguous, where the data of each channel is stored sequentially, and channels-last is the exact opposite. Can you help me understand what the error means here?
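
As a toy example of how I picture the two memory formats (illustrative only, not my training code):

import torch

x = torch.randn(2, 3, 4, 4)                   # default NCHW layout
print(x.is_contiguous())                       # True

y = x.to(memory_format=torch.channels_last)    # same data, NHWC strides
print(y.is_contiguous())                        # False for the default format
print(y.is_contiguous(memory_format=torch.channels_last))  # True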

Things I am also using along with the vmap function (a rough sketch of how they fit together follows the list):

  1. Mixed precision training
  2. Channels-last memory format (I have tried switching this off, but I still run into the same error)
  3. DistributedDataParallel
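
Roughly, this is how those three pieces are wired together (a heavily simplified sketch; the variable names follow the traceback above, but the details are illustrative):

# Heavily simplified sketch of how the three pieces fit together (illustrative only)
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def train(nets, trainloader, optimizer, scaler, local_rank):
    # In my real code the models are wrapped once at setup; shown inline for brevity
    nets = [net.to(memory_format=torch.channels_last) for net in nets]  # 2. channels-last weights
    nets = [DDP(net, device_ids=[local_rank]) for net in nets]          # 3. one DDP wrapper per model

    for inputs, targets in trainloader:
        inputs = inputs.cuda(non_blocking=True).to(memory_format=torch.channels_last)
        targets = targets.cuda(non_blocking=True)

        with torch.cuda.amp.autocast():                  # 1. mixed precision forward
            loss = loss_function(inputs, targets, nets)  # the vmapped loss above

        optimizer.zero_grad()
        scaler.scale(loss).backward()                    # scaler: torch.cuda.amp.GradScaler
        scaler.step(optimizer)
        scaler.update()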