Expectation Maximisation for Variational Inference

Hi fellow pytorchians,

I am trying to implement a novel variational inference algorithm that will be used under an Expectation Maximisation scheme.

Without going into too much detail initially, my problem boils down to modelling

f(inputs | w1, w2) 

where I want to iteratively update the model via two loss functions:

expectation_loss = loss(w1),
maximisation_loss = loss(w2)

Here w1 is a set of NN parameters and w2 is a single variational variable.
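One way to structure this with torch.nn.Module is sketched below. Keep w1 and w2 as separate parameter groups inside one module, give each group its own optimizer, and alternate the two updates. The MLP standing in for w1, the way w2 enters the forward pass, the bodies of `expectation_loss` and `maximisation_loss`, and the synthetic data are all hypothetical placeholders, since the post does not specify them. The key point is the pattern: zero the gradients on the whole module before each backward pass, then step only the optimizer that owns the parameters being updated.

```python
import torch
import torch.nn as nn


class EMModel(nn.Module):
    """Illustrative stand-in for f(inputs | w1, w2).

    w1: network weights (a small MLP here, chosen only for illustration).
    w2: a single scalar variational parameter registered as an nn.Parameter.
    """

    def __init__(self, in_dim: int = 4, hidden: int = 16):
        super().__init__()
        self.net = nn.Sequential(  # w1
            nn.Linear(in_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )
        self.w2 = nn.Parameter(torch.zeros(()))  # w2, a 0-dim tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical coupling: w2 simply shifts the network output.
        return self.net(x).squeeze(-1) + self.w2


# Placeholder losses: the real objectives are not given in the post.
def expectation_loss(model: EMModel, x, y) -> torch.Tensor:
    return nn.functional.mse_loss(model(x), y)


def maximisation_loss(model: EMModel, x, y) -> torch.Tensor:
    # MSE plus a hypothetical quadratic (Gaussian-prior-like) penalty on w2.
    return nn.functional.mse_loss(model(x), y) + 0.1 * model.w2 ** 2


def e_step(model: EMModel, opt_w1, x, y) -> float:
    """Update w1 only; w2 receives a gradient but opt_w1 never touches it."""
    model.zero_grad()  # clear stale gradients on *all* parameters
    loss = expectation_loss(model, x, y)
    loss.backward()
    opt_w1.step()
    return loss.item()


def m_step(model: EMModel, opt_w2, x, y) -> float:
    """Update w2 only; the network weights are not stepped here."""
    model.zero_grad()
    loss = maximisation_loss(model, x, y)
    loss.backward()
    opt_w2.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 4)
    y = x.sum(dim=1) + 0.5  # synthetic data, for demonstration only

    model = EMModel()
    # One optimizer per parameter group, so each step updates only its own block.
    opt_w1 = torch.optim.Adam(model.net.parameters(), lr=1e-2)
    opt_w2 = torch.optim.SGD([model.w2], lr=1e-1)

    for it in range(300):
        le = e_step(model, opt_w1, x, y)
        lm = m_step(model, opt_w2, x, y)
        if it % 100 == 0:
            print(f"iter {it}: E-loss {le:.4f}, M-loss {lm:.4f}, w2 {model.w2.item():.4f}")
```

Because each step still computes gradients for the parameters it does not update, a variant worth considering is to call `requires_grad_(False)` on the frozen group for the duration of the step (or `detach()` it in the forward pass), which avoids that extra work. Separate optimizers also matter if you use stateful ones such as Adam, since the moment estimates for w1 and w2 then stay independent.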
Would anyone have any tips on implementing this in PyTorch using torch.nn.Module?