MCMC Package: Efficient & Fast Storing of state_dicts for Accepting/Rejecting Model Samples

Hi there everybody,

I’ve been implementing an MCMC sampling package for PyTorch models for fun.

So far I’ve been using multiprocessing to run parallel chains on my CPU for little toy problems.

The next step was to modify the package such that it runs on GPU and/or multiple GPUs.
So far the idea is to call the convenient state_dict function every time a parameter sample is accepted and to load an old one every time a sample is rejected.
This means moving a lot of data between CPU and GPU, which isn’t a huge problem given that the forward pass is the bottleneck (the speed-up from using a GPU should outweigh the slower storing/loading compared to running everything on the CPU).
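To make the scheme concrete, here's a minimal sketch of what I mean, with a hypothetical toy `nn.Linear` standing in for the real model (the perturbation step is just a placeholder for the actual proposal):

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the actual sampler's model.
model = nn.Linear(4, 2)

def snapshot_cpu(model):
    # Clone every tensor and move it to the CPU, so the snapshot is
    # independent of later in-place updates to the model's parameters.
    return {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

def restore(model, snapshot):
    # load_state_dict copies the CPU tensors back into the model's
    # parameters, moving them onto the model's device as needed.
    model.load_state_dict(snapshot)

saved = snapshot_cpu(model)           # accept: remember the current parameters
with torch.no_grad():
    for p in model.parameters():      # propose: perturb parameters in place
        p.add_(torch.randn_like(p))
restore(model, saved)                 # reject: roll back to the accepted state
```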

Since I want to make it as fast as possible, I was wondering whether there are any caveats/gotchas/things-to-pay-attention-to regarding the repeated storing and loading of state_dicts/parameters in PyTorch.

So far my “smartest” idea has been to keep a parallel model with the last accepted parameters directly on the GPU. This admittedly reduces the GPU memory available for the actual forward/backward pass, but it avoids loading the parameters from the CPU every time a sampling move is rejected.
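Roughly, the shadow-copy variant I have in mind looks like this (a sketch, not the actual package code; `accept`/`reject` are names I made up). Since `Tensor.copy_` is an in-place device-to-device copy, nothing ever touches the CPU and no new memory is allocated after setup, at the cost of doubling the parameter memory:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)

# Shadow buffers on the SAME device as the model: one clone per parameter,
# holding the last accepted state.
shadow = [p.detach().clone() for p in model.parameters()]

@torch.no_grad()
def accept():
    # Copy the current parameters into the shadow buffers
    # (device-to-device, no CPU round trip).
    for s, p in zip(shadow, model.parameters()):
        s.copy_(p)

@torch.no_grad()
def reject():
    # Roll the model back to the last accepted parameters.
    for s, p in zip(shadow, model.parameters()):
        p.copy_(s)
```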

Does anybody have any recommendations concerning how to speed up loading/storing of state_dicts/parameters?

Thanks in advance! =)
