I have the following module in my Deep Neural network model to compute the log likelihood of Gaussian Mixture model.

```
def loglikelihood_gmm(self, x, mu, logvar, pi):
# init
loglike = 0
# for all data channels
for n in range(x.shape[1]):
# likelihood of a single mixture at evaluation point
assert not torch.isnan(mu[:, n, :]).any()
assert not torch.isnan(logvar[:, n, :]).any()
pred_dist = tdist.Normal(mu[:, n, :], logvar[:, n, :].exp().sqrt())
x_mod = torch.mm(x[:, n].unsqueeze(1), torch.ones(1, self.n_mixtures, device=self.device))
like = pred_dist.log_prob(x_mod)
# weighting by probability of mixture and summing
temp = (pi[:, n, :] * like)
temp = temp.sum()
# log-likelihood added to previous log-likelihoods
loglike = loglike + temp
return loglike
```

However, I got out of memory error using this function

```
/tmp/ipykernel_86712/2637053337.py in loglikelihood_gmm(self, x, mu, logvar, pi)
738 x_mod = torch.mm(x[:, n].unsqueeze(1), torch.ones(1, self.n_mixtures, device=self.device))
739
--> 740 like = pred_dist.log_prob(x_mod)
741 # weighting by probability of mixture and summing
742 temp = (pi[:, n, :] * like)
~/anaconda3/lib/python3.9/site-packages/torch/distributions/normal.py in log_prob(self, value)
81 var = (self.scale ** 2)
82 log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
---> 83 return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
84
85 def cdf(self, value):
OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 11.76 GiB total capacity; 10.57 GiB already allocated; 1.94 MiB free; 10.77 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

Is there any other efficient way to write this function to avoid getting this error?

Update:

I thought maybe using the MixtureSameFamily distribution from pytorch can solve the cuda memory error for computing the log-likelihood of Gaussian mixture model (GMM). Here is my attempt:

```
def loglikelihood_gmm(self, x, mu, logvar, pi):
batch_shape=x.shape[0]
num_components=mu.shape[2]
n_dim=x.shape[1]
# for all data channels
# likelihood of a single mixture at evaluation point
assert not torch.isnan(mu).any()
assert not torch.isnan(logvar).any()
means = mu.permute(0,2,1) #batch_shape, num_components, dim
sigma = logvar.mul(0.5).exp_()
scale = sigma.permute(0,2,1) #torch.Size([B, n_comp, dim])
#(batch_size, n_components, dim, dim)
A = torch.diag_embed(scale, dim1=-1, dim2=-2)
covar = A.transpose(2,3).matmul(A) #torch.Size([batch_size, n_components, dim, dim])
weights = pi.permute(0,2,1) #(Batch_size, n_comp, dim)
print(scale.shape,scale.tril(-1).shape, A.shape, weights.shape, means.shape)
# [*batch_dims, comp, dim] --> [*batch_dims, 1, comp, dim] to evaluate batched inputs
mix = tdist.Categorical(weights.unsqueeze(-2))
comp = tdist.MultivariateNormal(means, covar)
#x: Batch, dim ===> Batch,n_comp, dim
gmm = tdist.MixtureSameFamily(mix, comp)
x_mod = x.unsqueeze(1).repeat(1,self.n_mixtures,1)
x = gmm._pad(x_mod) # noqa [*batch,n_comp,1, dim]
return -gmm.log_prob(x).mean()
```

Now I get this error which I canâ€™t figure out how to solve it:

```
torch.Size([1, 15, 12]) torch.Size([1, 15, 12]) torch.Size([1, 15, 12, 12]) torch.Size([1, 15, 12]) torch.Size([1, 15, 12])
---------------------------------------------------------------------------
/tmp/ipykernel_93455/3995549603.py in forward(self, u, y)
--> 637 loss_pred = self.loglikelihood_gmm(y[:, :, t], dec_mean_t, dec_logvar_t, dec_pi_t)
638 loss += - loss_pred + KLD
639
/tmp/ipykernel_93455/3995549603.py in loglikelihood_gmm(self, x, mu, logvar, pi)
750 comp = tdist.MultivariateNormal(means, covar)
751 #x: Batch, dim ===> Batch,n_comp, dim
--> 752 gmm = tdist.MixtureSameFamily(mix, comp)
753 x_mod = x.unsqueeze(1).repeat(1,self.n_mixtures,1)
754 x = gmm._pad(x_mod) # noqa [*batch,n_comp,1, dim]
~/anaconda3/lib/python3.9/site-packages/torch/distributions/mixture_same_family.py in __init__(self, mixture_distribution, component_distribution, validate_args)
80 kc = self._component_distribution.batch_shape[-1]
81 if km is not None and kc is not None and km != kc:
---> 82 raise ValueError("`mixture_distribution component` ({0}) does not"
83 " equal `component_distribution.batch_shape[-1]`"
84 " ({1})".format(km, kc))
ValueError: `mixture_distribution component` (12) does not equal `component_distribution.batch_shape[-1]` (15)
```

Any suggestion for solving either way to compute the log-likelihood of GMM?