Bayesian Hierarchical models in pytorch (BayesianGMM)

I am aware of pyro facilitating probabilistic models through standard SVI inference. But is it possible to write Bayesian models in pure pytorch? Say for instance, MAP training in Bayesian GMM.
I specify a bunch of priors and a likelihood, provide a MAP objective and learn point estimates but I am missing something key in my attempt here, perhaps whats confusing is that means, scales and weights have priors but they still need to be trainable as params in MAP.

import numpy as np
import torch
from tqdm import trange
import torch.distributions as D
from matplotlib.patches import Ellipse
import matplotlib.pyplot as plt
import sklearn.datasets as skd

class BayesianGMM(torch.nn.Module):
    
    def __init__(self, Y, K, dim=2):
        '''
          :param Y (torch.tensor): data
          :param K (int): number of fixed components
          :param D (int): the dimension of data space.
        
        '''
        super(BayesianGMM, self).__init__()
        
        self.Y = Y
        self.K = K
        self.dim = dim
                
        self.mu_prior = torch.zeros((self.K,self.dim))
        self.sig_prior = torch.ones((self.K, self.dim))*3.0
        self.concentration = torch.randint(high=100,size=(1,))
        self.alpha = torch.nn.Parameter(torch.tensor([self.concentration]*self.K).float())
        
        self.means = D.Normal(self.mu_prior, self.sig_prior).sample()
        self.scales = D.Gamma(torch.tensor([1.0]), torch.tensor([1.0])).sample()
        self.weights = D.Dirichlet(self.alpha).sample()
        self.z = D.Categorical(self.weights)
    
    def forward(self, Y):
        
        comp = D.Independent(D.Normal(self.means, self.scales), 1)
        return D.MixtureSameFamily(self.z, comp).log_prob(Y)
    
    def log_prior(self, Y):
        
        self.means.log_prob(Y) + self.scales.log_prob(Y) + self.weights.log_prob(Y)
    
    def get_trainable_param_names(self):
      
      ''' Prints a list of parameters which will be 
      learnt in the process of optimising the objective '''
      
      for name, value in self.named_parameters():
          print(name)     
    
    def loglikelihood(self, Y):
        return self.forward(Y).mean()

    def map_objective(self, Y):
        return self.loglikelihood(Y) + self.log_prior(Y)
      
    def train(self, Y, optimizer, n_steps):
    
        losses = np.zeros(n_steps)
        bar = trange(n_steps, leave=True)
        for step in bar:
            optimizer.zero_grad()
            loss = self.map_objective(Y)
            loss.backward()
            optimizer.step()
            losses[step] = loss
        return losses
        
if __name__ == '__main__':
    
    # Generate some data
    
    Y, labels = skd.make_blobs(n_samples=2000, random_state=42,
                              cluster_std=[2.5, 0.5, 1.0])
    
    Y = torch.tensor(Y).float()
    
    model = BayesianGMM(Y, K=3)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    losses = model.train(Y, optimizer, 4000)
  1. priors are normally added semi-implicitly by adding KL loss terms - same way as in VAE
  2. using distribution objects or sampling in __init__ is incorrect. you also need to use rsample() for gradient flows
  3. these changes still won’t be enough for this model, as discrete distributions won’t train without implementing REINFORCE (surrogate loss) - this is something that pyro would handle with less effort