Hi fellow pytorchians,
I am trying to implement a novel variational inference algorithm that will be used under an Expectation Maximisation scheme.
Without going into too much detail initially, my problem boils down to modelling
f(inputs | w1, w2)
where I want to iteratively update the model via two loss functions:
expectation_loss = loss(w1),
maximisation_loss = loss(w2)
Here w1 is a set of NN parameters and w2 is a single variational variable.
Would anyone have any tips on implementing this in PyTorch using torch.nn.Module?
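
For concreteness, a minimal sketch of one way the alternating update described above could be set up. Everything beyond w1, w2, expectation_loss and maximisation_loss is an assumption made for illustration: the class name EMModel, the helper em_iteration, the way w2 enters the forward pass, the squared-error placeholder loss, the random data and the choice of Adam are all hypothetical stand-ins, not the actual algorithm.

```python
import torch
import torch.nn as nn


class EMModel(nn.Module):
    """Hypothetical stand-in for f(inputs | w1, w2)."""

    def __init__(self, in_dim=4, hidden=8):
        super().__init__()
        # w1: the NN parameters (assumed here to be a small MLP)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )
        # w2: a single variational variable, registered as a parameter
        self.w2 = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Assumption: w2 acts as a log-scale on the network output
        return self.net(x) * torch.exp(self.w2)


def placeholder_loss(pred, target):
    # Placeholder objective; the real losses would replace this
    return ((pred - target) ** 2).mean()


def em_iteration(model, opt_w1, opt_w2, x, y):
    # Update w1 only: opt_w1 holds just the network parameters
    opt_w1.zero_grad()
    expectation_loss = placeholder_loss(model(x), y)
    expectation_loss.backward()
    opt_w1.step()

    # Update w2 only: zero_grad clears the gradient that the
    # previous backward pass left on w2 before it is recomputed
    opt_w2.zero_grad()
    maximisation_loss = placeholder_loss(model(x), y)
    maximisation_loss.backward()
    opt_w2.step()
    return expectation_loss.item(), maximisation_loss.item()


torch.manual_seed(0)
model = EMModel()
# One optimizer per parameter set, so each step touches only its own
opt_w1 = torch.optim.Adam(model.net.parameters(), lr=1e-2)
opt_w2 = torch.optim.Adam([model.w2], lr=1e-2)

x = torch.randn(32, 4)  # random toy inputs
y = torch.randn(32, 1)  # random toy targets

for _ in range(10):
    e_loss, m_loss = em_iteration(model, opt_w1, opt_w2, x, y)
```

The design choice in this sketch is to keep both w1 and w2 inside a single nn.Module and separate them at the optimizer level. Each backward pass still computes gradients for every parameter, but only the parameters owned by the optimizer being stepped are changed. Other routes, such as toggling requires_grad on one set while the other is updated, could work as well.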