Gradient checkpointing for a Bayesian Neural Network

Hi,
I am considering using gradient checkpointing to reduce VRAM usage. From what I understand, at some point there were issues with applying gradient checkpointing to stochastic nodes (e.g. Dropout).
However, I have a kind of Bayesian Neural Network that needs quite a bit of memory, hence my interest in gradient checkpointing. Since all the weights are Bayesian, there is stochasticity everywhere, and it can’t be removed.
Can I safely apply gradient checkpointing to this Bayesian Neural Network?

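Based on the description in the docs, checkpointing by default makes stochastic ops such as dropout deterministic with respect to a non-checkpointed pass (the recomputation reuses the same RNG state), and this behavior can be disabled: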

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be more advanced than they would without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.

So it seems your use case should be fine.
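To check this on the Bayesian case directly, here is a minimal sketch. It assumes a recent PyTorch release (one where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant`) running on CPU; `BayesLinear` and `forward_backward` are hypothetical names for illustration, not part of PyTorch. The layer draws a fresh weight sample on every forward call, so it consumes the RNG just like Dropout. The same seeded step is run without checkpointing, with checkpointing and the default `preserve_rng_state=True`, and with `preserve_rng_state=False`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class BayesLinear(nn.Module):
    """Hypothetical mean-field Gaussian linear layer, for illustration only.

    Every forward call draws a fresh weight sample
    W = mu + softplus(rho) * eps with eps ~ N(0, 1) (reparameterization trick),
    so the layer consumes the RNG just like Dropout does.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.mu = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.rho = nn.Parameter(torch.full((out_features, in_features), -3.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        eps = torch.randn_like(self.mu)  # new noise sample on every call
        weight = self.mu + F.softplus(self.rho) * eps
        return F.linear(x, weight)


def forward_backward(model, x, use_checkpoint, preserve_rng_state=True):
    """Run one seeded forward/backward pass; return the output and the grads."""
    model.zero_grad()
    torch.manual_seed(123)  # identical starting RNG state for every run
    if use_checkpoint:
        # Activations inside `model` are discarded and recomputed in backward.
        out = checkpoint(model, x, use_reentrant=False,
                         preserve_rng_state=preserve_rng_state)
    else:
        out = model(x)
    # Nonlinear loss so that the gradients actually depend on the sampled eps.
    torch.tanh(out).pow(2).sum().backward()
    return out.detach(), [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = BayesLinear(16, 8)
x = torch.randn(4, 16)

ref_out, ref_grads = forward_backward(model, x, use_checkpoint=False)
ckpt_out, ckpt_grads = forward_backward(model, x, use_checkpoint=True)
free_out, free_grads = forward_backward(model, x, use_checkpoint=True,
                                        preserve_rng_state=False)

# Default: the recomputation in backward replays the same eps as the forward.
print("default matches:", torch.equal(ref_out, ckpt_out)
      and all(torch.equal(a, b) for a, b in zip(ref_grads, ckpt_grads)))
# preserve_rng_state=False: backward recomputes with a different eps.
print("no RNG preservation, grads match:",
      all(torch.equal(a, b) for a, b in zip(ref_grads, free_grads)))
```

With the default setting, the output and both parameter gradients should match the non-checkpointed run bitwise. With `preserve_rng_state=False` the forward output still matches, but the backward pass recomputes the layer with a different noise sample, so the gradients no longer belong to the weight sample used in the forward pass. For a Bayesian model that mismatch matters more than a loss of reproducibility, so the default is the setting to keep here.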
Could you link to the known issues with stochastic nodes? I remember there were some, but couldn’t find them quickly.

Thanks for your reply! 🙂
I found the issue discussed in this tutorial, in the section “Handling a few special layers in checkpointing”: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb

Thanks for the link!
While this notebook gives a really good example of the usage, note that it’s a bit outdated by now, and if I’m not mistaken, @mcarilli’s PR should have enabled bitwise-identical results between standard and checkpointed models.

These tests should also verify this behavior.
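A check in the same spirit can be sketched as follows; it is an illustration, not one of the tests mentioned above. It assumes a recent PyTorch release on CPU, and `grads_after_step` is a hypothetical helper. A small `nn.Sequential` with Dropout layers is run through one seeded step plainly and once through `torch.utils.checkpoint.checkpoint_sequential` with 2 segments (so the first Dropout sits inside a checkpointed segment), and outputs and gradients are compared with `torch.equal`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


def grads_after_step(model, x, segments=None):
    """Seeded forward/backward; checkpoint `segments` chunks if given."""
    model.zero_grad()
    torch.manual_seed(42)  # same dropout masks for both runs
    if segments is None:
        out = model(x)
    else:
        # The first segment(s) are recomputed during backward.
        out = checkpoint_sequential(model, segments, x, use_reentrant=False)
    out.pow(2).mean().backward()
    return out.detach(), [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(32, 64), nn.ReLU(), nn.Dropout(0.5),
    nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.5),
    nn.Linear(64, 1),
)
model.train()  # keep dropout active
x = torch.randn(8, 32)

ref_out, ref_grads = grads_after_step(model, x)
ckpt_out, ckpt_grads = grads_after_step(model, x, segments=2)

bitwise_equal = torch.equal(ref_out, ckpt_out) and all(
    torch.equal(a, b) for a, b in zip(ref_grads, ckpt_grads)
)
print("bitwise identical:", bitwise_equal)
```

If the RNG state were not restored (`preserve_rng_state=False`), the recomputed Dropout in the first segment would draw a different mask and the gradients would no longer match.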


Thank you very much for the information! 🙂 It worked pretty well in my first few experiments.
Have a nice day 🙂