My use case requires me to train a list of models together. This is a heavily simplified version of my loss function:
def loss_function(inputs, targets, models):
    """Return the mean cross-entropy loss of ``models`` on one batch.

    Every model is called on the same ``inputs``; its output is scored against
    ``targets`` with ``F.cross_entropy`` and the per-model losses are averaged.

    Args:
        inputs: Batch passed unchanged to every model.
        targets: Targets passed to ``F.cross_entropy`` for every model.
        models: Sized iterable of callables mapping ``inputs`` to logits.

    Returns:
        The sum of the per-model losses divided by ``len(models)``.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("models must contain at least one model")
    # Start from an explicit zero: the accumulator was previously read by
    # ``+=`` before ever being assigned, which raises UnboundLocalError.
    loss_std = 0.0
    for model in models:
        outputs = model(inputs)
        # Out-of-place add keeps the autograd graph of each term intact.
        loss_std = loss_std + F.cross_entropy(outputs, targets)
    return loss_std / len(models)
I was attempting to vectorize the models using something like this:
from functorch import combine_state_for_ensemble, vmap
def loss_function(inputs, targets, models):
    """Return the mean per-model cross-entropy loss using one vmapped forward pass.

    The models' parameters and buffers are stacked with
    ``combine_state_for_ensemble`` and the functional model is mapped over
    that stack with ``vmap``, with the same ``inputs`` shared by every model.

    Args:
        inputs: Batch shared by every model (``in_dims`` entry ``None``).
        targets: Targets passed to ``F.cross_entropy``; the same targets are
            used for every model.
        models: Models accepted by ``combine_state_for_ensemble``.

    Returns:
        The cross-entropy averaged over all (model, sample) pairs.
    """
    # NOTE(review): combine_state_for_ensemble is re-run on every call and
    # presumably returns new stacked tensors -- verify that gradients reach
    # the parameters the optimizer actually updates.
    fmodel, params, buffers = combine_state_for_ensemble(models)
    # in_dims=(0, 0, None): map over the leading dim of params and buffers,
    # and pass inputs unchanged to every mapped call.
    outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
    # Dim 0 of outputs indexes the models. Scoring the mean of the outputs is
    # not the same as averaging the per-model losses, so fold the model dim
    # into the batch dim instead; the default mean reduction then averages
    # over every (model, sample) pair.
    num_models = outputs.shape[0]
    stacked_targets = targets.expand(num_models, *targets.shape)
    loss_std = F.cross_entropy(
        outputs.flatten(0, 1), stacked_targets.flatten(0, 1)
    )
    return loss_std
I’m training ResNet-18 models, which contain batch norm layers.
When I tried to run the new vectorized loss function, I got the following error:
Traceback (most recent call last):
File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 362, in <module>
run()
File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 359, in run
torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, join=True)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
while not context.join():
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 160, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 349, in main_worker
train(nets, ema_nets, trainloader, optimizer, lr_scheduler, scaler, attack)
File "/home/nitish/projects/Fast-Adv-Ensemble/train_adv_ensemble.py", line 137, in train
loss, _, _ = trades_loss(
File "/home/nitish/projects/Fast-Adv-Ensemble/utils/loss.py", line 75, in trades_loss
nat_loss, nat_outputs = adp_loss(
File "/home/nitish/projects/Fast-Adv-Ensemble/utils/loss.py", line 50, in adp_loss
outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, inputs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 362, in wrapped
return _flat_vmap(
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 35, in fn
return f(*args, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/vmap.py", line 489, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/functorch/_src/make_functional.py", line 282, in forward
return self.stateless_model(*args, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nitish/projects/Fast-Adv-Ensemble/models/resnet.py", line 235, in forward
return self._forward_impl(x)
File "/home/nitish/projects/Fast-Adv-Ensemble/models/resnet.py", line 219, in _forward_impl
x = self.bn1(x)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 740, in forward
return F.batch_norm(
File "/home/nitish/mambaforge/envs/fast/lib/python3.9/site-packages/torch/nn/functional.py", line 2450, in batch_norm
return torch.batch_norm(
RuntimeError: NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
I’m guessing NYI means Not Yet Implemented. Is there a way to make it work for my case? I am currently using PyTorch version 1.13.1. Is it implemented in 2.0.0 or in the nightly builds?
Please let me know if there is a fix for this or an alternative approach I could try out to vectorize or speed up the function.