Questions on Model Ensembling Tutorial using vmap


I stumbled on the model ensembling tutorial on the official PyTorch website. It stacks the parameters of an ensemble of models that all share the same architecture and uses vmap to vectorise the forward pass. The tutorial promises a significant speedup, and it looked like a good fit for what I'm doing, since I'm working with something quite similar to an ensemble.

I have a few questions about that tutorial that I couldn't easily answer for myself.

The forward pass worked just fine.
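For reference, this is roughly the pattern from the tutorial that I used, as a minimal sketch (the layer sizes, ensemble size, and batch size are placeholders of mine):

```python
import copy
import torch
from torch.func import stack_module_state, functional_call

# A toy ensemble: 3 models sharing one architecture.
models = [torch.nn.Linear(4, 2) for _ in range(3)]

# Stack each parameter along a new leading "model" dimension.
params, buffers = stack_module_state(models)

# A stateless "base" copy, used only for its structure.
base = copy.deepcopy(models[0]).to("meta")

def fmodel(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 4)  # the same batch is fed to every model
# vmap over the stacked model dimension; x is shared (in_dims=None).
out = torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # torch.Size([3, 8, 2]): one prediction per model
```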

Consequently, I thought: wait, couldn't you use the same technique in a training loop and speed that up as well?

So I did the same thing, feeding basically the same input to all models, and used the vmapped output as the prediction in the loss. Each model has its own optimizer, so I zeroed the gradients of each model in a loop, then computed the loss and called backward to get the gradients. Afterwards, I iterated over the optimizers and called step.
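Roughly what I tried, as a minimal sketch (model sizes, loss, and data are placeholders; the last comment shows the behaviour I observed):

```python
import copy
import torch
from torch.func import stack_module_state, functional_call

models = [torch.nn.Linear(4, 1) for _ in range(3)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in models]

params, buffers = stack_module_state(models)
base = copy.deepcopy(models[0]).to("meta")

def fmodel(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 4)
y = torch.randn(8, 1)

# Zero each model's gradients in a loop.
for opt in optimizers:
    opt.zero_grad()

preds = torch.vmap(fmodel, (0, 0, None))(params, buffers, x)  # (3, 8, 1)
# Default reduction averages over everything, models included.
loss = torch.nn.functional.mse_loss(preds, y.expand(3, 8, 1))
loss.backward()

# The gradients land on the stacked tensors, not on the original
# models, so models[0].weight.grad stays None.
print(models[0].weight.grad)        # None
print(params["weight"].grad.shape)  # torch.Size([3, 1, 4])
```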

Admittedly, I probably already messed up the loss, since I didn't deactivate the reduction… I want each model to have its own loss, unaffected by that of the others.
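What I mean by per-model losses, as a small sketch (shapes are my assumptions): with `reduction="none"` you can reduce only over the batch and output dimensions and keep one scalar loss per model.

```python
import torch

preds = torch.randn(3, 8, 1)   # (num_models, batch, out)
target = torch.randn(8, 1)     # shared target

# Keep the elementwise losses, then average per model only.
per_elem = torch.nn.functional.mse_loss(
    preds, target.expand_as(preds), reduction="none")
per_model = per_elem.mean(dim=(1, 2))  # one loss per model
print(per_model.shape)  # torch.Size([3])

# Since the models share no parameters, summing the per-model losses
# before backward still gives each model only its own gradients.
total = per_model.sum()
```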

More problematically, though, the gradient calculation didn't seem to work: the models received no gradients, and their .grad fields remained None.

So my question is: does this idea work at all for backpropagation, or does it break the computational graph somehow?

Aside from that: the operation is vectorised, but does that mean PyTorch parallelises the different forward passes on the fly on a single GPU? Is it one tensor operation in the end, or is it still basically syntactic sugar for a loop, just done in C++?
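As far as I understand, vmap rewrites each operation via a batching rule into a single batched tensor op rather than looping at the Python level. A toy comparison I tried (my own example, not from the tutorial):

```python
import torch

W = torch.randn(3, 2, 4)   # one weight matrix per "model"
x = torch.randn(8, 4)      # shared input

# Explicit Python loop: three separate matmuls.
looped = torch.stack([x @ W[i].T for i in range(3)])

# vmap version: the per-model matmul becomes one batched matmul.
vmapped = torch.vmap(lambda w: x @ w.T)(W)

print(torch.allclose(looped, vmapped))  # True, same result
```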

Thanks for reading :slight_smile:

Were you able to find a solution for this? I'm facing the same issue while trying to speed up my training using vmap.
