I am trying to train a new type of generative network called a Generative Flow Network. Often, the most expensive part of training these networks is generating samples from the network itself. I looked into DDP, but what I'd really like is to parallelize the training loop or the training_step in PyTorch Lightning.
For instance, if we were to use DDP on this Deep Energy Model here, would the sampling that happens in the training_step also be parallelized?
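To make the question concrete, here is a minimal sketch of the kind of sampling loop I mean (the `EnergyNet` model and `langevin_sample` function are hypothetical stand-ins for the Deep Energy Model, not its actual code). My understanding is that under DDP each process would run this loop on its own model replica with its own batch, so the sampling would happen in parallel across ranks:

```python
import torch
import torch.nn as nn

# Hypothetical tiny energy model standing in for the Deep Energy Model.
class EnergyNet(nn.Module):
    def __init__(self, dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, x):
        return self.net(x).squeeze(-1)

def langevin_sample(model, batch_size=8, dim=2, steps=10, lr=0.1):
    """Langevin-style sampling loop of the kind a training_step might run.

    Under DDP, each rank would execute this independently on its own
    replica, so per-rank sample batches are generated in parallel.
    """
    x = torch.randn(batch_size, dim, requires_grad=True)
    for _ in range(steps):
        energy = model(x).sum()
        (grad,) = torch.autograd.grad(energy, x)
        with torch.no_grad():
            # Gradient step on the input plus a small noise term.
            x = x - lr * grad + 0.01 * torch.randn_like(x)
        x.requires_grad_(True)
    return x.detach()

model = EnergyNet()
samples = langevin_sample(model)
print(samples.shape)  # torch.Size([8, 2])
```

Is it correct that this per-rank sampling is the only parallelism DDP would give here?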
I also looked into building a DataLoader that depends on the model/network, but it appears as though each worker would copy the model at step 0 and never update it, since there is no mechanism to keep those copies' parameters in sync with the main process.
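This is roughly what I tried (a hypothetical sketch, not my exact code): an `IterableDataset` that holds a reference to the model and draws samples from it. With `num_workers > 0`, each worker process gets a pickled copy of the dataset, and hence of the model, at startup, and later parameter updates in the main process never reach those copies:

```python
import torch
from torch.utils.data import IterableDataset, DataLoader

class ModelSampleDataset(IterableDataset):
    """Hypothetical dataset that draws samples using the current model.

    With num_workers > 0, each worker receives a pickled copy of this
    dataset (and of `model`) when it starts; subsequent parameter
    updates in the main process never propagate to those copies.
    """
    def __init__(self, model):
        self.model = model

    def __iter__(self):
        while True:
            with torch.no_grad():
                z = torch.randn(1, 2)
                yield self.model(z)[0]  # one sample per iteration

model = torch.nn.Linear(2, 1)
# num_workers=0 here just to keep the sketch runnable; the staleness
# problem described above appears once num_workers > 0.
loader = DataLoader(ModelSampleDataset(model), batch_size=4, num_workers=0)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 1])
```

Is there a supported way to refresh the workers' model copies, or is this pattern a dead end?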
What is the most effective way to parallelize sample generation (e.g., RL-style trajectories) in generative models?