How to optimize the input distribution?

Hello, I’m trying to optimize an input distribution, used as side information for a (perturbed) input image, to help reconstruction. The loss function is an L2 loss.

When the input distribution is a standard Normal distribution (sampled directly with torch.randn()), everything works well.
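Concretely, the working baseline is just something like this (the shape here is only an example; in practice it matches what the reconstruction network expects):

import torch

zshape = (1, 1, 32, 32)   # example shape of the side information
z = torch.randn(zshape)   # fixed, non-learnable N(0, 1) samples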

I’m now trying to extend it to a GMM (Gaussian mixture model) distribution with learnable parameters. My implementation is as follows:

import torch
import torch.nn as nn
import torch.distributions as D

class GaussianMixture(nn.Module):
    def __init__(self, opt):
        super(GaussianMixture, self).__init__()
        if opt['gmm_init']:
            # rows of gmm_para: mixture logits (pis), means (mus), log-variances (logvars)
            self.gmm_para = nn.Parameter(torch.Tensor([opt['pis'], opt['mus'], opt['logvars']]))
            if opt['gmm_learnable']:  # small random perturbation of the initial values
                with torch.no_grad():
                    self.gmm_para += torch.randn(self.gmm_para.shape)
        else:
            self.gmm_para = nn.Parameter(torch.randn(3, opt['n_components']))
    
    def refresh(self):
        self.pis, self.mus, self.logvars = self.gmm_para

    def build(self):
        self.refresh()
        std = torch.exp(0.5 * torch.clamp(self.logvars, -20, 20))  # std = exp(logvar / 2), clamped for stability
        weights = self.pis.softmax(dim=0)  # mixture weights, sum to 1
        mix = D.Categorical(weights)
        comp = D.Normal(self.mus, std)
        self.gmm = D.MixtureSameFamily(mix, comp)
        
    def sample(self, zshape):
        self.build()
        return self.gmm.sample(zshape)

It has its own optimizer. At each iteration, I sample a sequence with the current learned GMM parameters to help the reconstruction, then backpropagate the loss and step the optimizer.
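Roughly, one iteration looks like this (the reconstruction network, images, shapes, and opt values below are only toy placeholders for my actual pipeline):

import torch
import torch.nn as nn
import torch.nn.functional as F

# toy placeholders standing in for the real dataset, model and config
opt = {'gmm_init': False, 'n_components': 5}
perturbed_img = torch.randn(1, 3, 32, 32)
target_img = torch.randn(1, 3, 32, 32)
recon_net = nn.Conv2d(4, 3, 3, padding=1)   # stand-in for the real reconstruction network
zshape = (1, 1, 32, 32)                     # side information fed as an extra channel

gmm_module = GaussianMixture(opt)
gmm_optimizer = torch.optim.Adam(gmm_module.parameters(), lr=1e-3)

# one training iteration
z = gmm_module.sample(zshape)                            # sample with the current GMM parameters
recon = recon_net(torch.cat([perturbed_img, z], dim=1))  # reconstruction conditioned on z
loss = F.mse_loss(recon, target_img)                     # L2 loss
gmm_optimizer.zero_grad()
loss.backward()
gmm_optimizer.step()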

However, I found that all the parameters tend to be optimized towards 0. With the implementation above, this may still be reasonable, since the resulting distribution is then just a single standard Normal distribution N(0, 1).
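As a quick sanity check of that reading, with all parameters at 0 the mixture really does reduce to N(0, 1) (standalone toy example, same parameterization as above):

import torch
import torch.distributions as D

n_components = 5
pis = torch.zeros(n_components)       # softmax of zeros -> uniform mixture weights
mus = torch.zeros(n_components)       # all component means at 0
logvars = torch.zeros(n_components)   # exp(0.5 * 0) = 1 -> unit std

gmm = D.MixtureSameFamily(D.Categorical(pis.softmax(dim=0)),
                          D.Normal(mus, torch.exp(0.5 * logvars)))
print(gmm.mean, gmm.variance)         # tensor(0.) tensor(1.)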

But when I optimize the std directly (rather than logvar = log(std**2)), it also converges to 0, instead of to 1, which would give a standard Normal distribution. The new code is as follows:

class GaussianMixture(nn.Module):
    def __init__(self, opt):
        super(GaussianMixture, self).__init__()
        if opt['gmm_init']:
            self.gmm_para = nn.Parameter(torch.Tensor([opt['pis'], opt['mus'], opt['logvars']]))
            if opt['gmm_learnable']:  # small random perturbation of the initial values
                with torch.no_grad():
                    self.gmm_para += torch.randn(self.gmm_para.shape)  # std 0.0001
        else:
            self.gmm_para = nn.Parameter(torch.randn(3, opt['n_components']))  # std 0.0001
    
    def refresh(self):
        self.pis, self.mus, self.logvars = self.gmm_para

    def build(self):
        self.refresh()
        std = torch.clamp(self.logvars, 0.1, 20)  # this row of gmm_para is now treated directly as std
        weights = self.pis.softmax(dim=0)
        mix = D.Categorical(weights)
        comp = D.Normal(self.mus, std)
        self.gmm = D.MixtureSameFamily(mix, comp)
        
    def sample(self, zshape):
        self.build()
        return self.gmm.sample(zshape)

Therefore, I think there might be some mistake here that pushes all the parameters of the input distribution towards 0.

It has confused me a lot; could you kindly help me figure it out?

Thanks in advance!