Probabilistic layer, learning distribution parameters?

I’m wondering if it is possible to train parameters used to generate a distribution. For example, the loc and scale values for torch.distributions.Normal().

I have created my own version of torch.nn.Linear() where my hope was to train a mean and standard deviation to be used in generating the weight values. However, it seems that the parameters for the distribution are not tracked as part of the gradient. Any suggestions on how I might accomplish this?

For reference, here is my code for a probabilistic linear layer:

import torch

class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, n_scale=0.01, **kwargs):
        super().__init__()
        # Distribution parameters for the weights
        loc_val = torch.zeros((out_features, in_features))
        scale_val = torch.ones((out_features, in_features)) * n_scale
        self.w_loc = torch.nn.parameter.Parameter(loc_val, requires_grad=True)
        self.w_scale = torch.nn.parameter.Parameter(scale_val, requires_grad=True)

        if bias:
            # Distribution parameters for the bias
            loc_val = torch.zeros((out_features,))
            scale_val = torch.ones((out_features,)) * n_scale
            self.b_loc = torch.nn.parameter.Parameter(loc_val, requires_grad=True)
            self.b_scale = torch.nn.parameter.Parameter(scale_val, requires_grad=True)

    def forward(self, x: torch.Tensor):
        weight = torch.distributions.Normal(loc=self.w_loc, scale=self.w_scale).sample().requires_grad_(True)
        if hasattr(self, 'b_loc'):
            bias = torch.distributions.Normal(loc=self.b_loc, scale=self.b_scale).sample().requires_grad_(True)
            return x @ weight.T + bias
        else:
            return x @ weight.T

I’m wondering if it is possible to train parameters used to generate a distribution.

You can do that; however, you’ll need to tweak a few things in your implementation:

  • You’re calling requires_grad_(True) on the sample, which is not needed because you want to update loc and scale, not the sample itself. You can remove that.

it seems that the parameters for the distribution are not tracked as part of the gradient.

  • That’s because no computation graph is being created. If you check the code for sample(), the whole call is wrapped in torch.no_grad(). The fix is to use rsample() instead, which draws the sample via the reparameterization trick so gradients can flow back into loc and scale (see the quick check below).
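
To see the difference concretely, here is a minimal check (an illustrative snippet, not from the original reply): sample() returns a tensor detached from the graph, while rsample() stays connected to loc and scale.

import torch

loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3, requires_grad=True)
dist = torch.distributions.Normal(loc, scale)

print(dist.sample().grad_fn)   # None -- drawn under torch.no_grad()
print(dist.rsample().grad_fn)  # e.g. <AddBackward0> -- loc + eps * scale, differentiable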

The rest looks good, so your code becomes:

class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, n_scale=0.01, **kwargs):
        super().__init__()
        # Trainable distribution parameters for the weights
        loc_val = torch.zeros((out_features, in_features))
        scale_val = torch.ones((out_features, in_features)) * n_scale
        self.w_loc = torch.nn.parameter.Parameter(loc_val, requires_grad=True)
        self.w_scale = torch.nn.parameter.Parameter(scale_val, requires_grad=True)

        if bias:
            # Trainable distribution parameters for the bias
            loc_val = torch.zeros((out_features,))
            scale_val = torch.ones((out_features,)) * n_scale
            self.b_loc = torch.nn.parameter.Parameter(loc_val, requires_grad=True)
            self.b_scale = torch.nn.parameter.Parameter(scale_val, requires_grad=True)

    def forward(self, x: torch.Tensor):
        # rsample() uses the reparameterization trick, so gradients flow back to loc/scale
        weight = torch.distributions.Normal(loc=self.w_loc, scale=self.w_scale).rsample()
        if hasattr(self, 'b_loc'):
            bias = torch.distributions.Normal(loc=self.b_loc, scale=self.b_scale).rsample()
            return x @ weight.T + bias
        else:
            return x @ weight.T
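
As a quick sanity check (an illustrative snippet, not from the original reply), you can confirm that the distribution parameters now receive gradients after a backward pass:

layer = ProbLinear(4, 3)
out = layer(torch.randn(8, 4))
out.sum().backward()
print(layer.w_loc.grad is not None, layer.w_scale.grad is not None)  # True True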

Hope it helps!

Fantastic! Thank you. It does seem to work. Now to see if I can actually get it to optimize.

For anyone who finds this and wishes to explore: here is the very simple test case I used to verify that it actually works (though it may be very inefficient).

Using a slightly modified ProbLinear(), as suggested by @krypticmouse, and a simple sine dataset…

import math

import numpy as np
import torch

class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__()
        self.eps = torch.tensor([1e-6])  # lower bound on the half-width of the Uniform
        # Weight center and half-width, both trainable
        self.w_val = torch.nn.parameter.Parameter(torch.randn((out_features, in_features)).type(torch.float32), requires_grad=True)
        self.w_std = torch.nn.parameter.Parameter(torch.ones_like(self.w_val) * torch.Tensor([np.sqrt(6 / in_features)]).type(torch.float32), requires_grad=True)

        if bias:
            # Bias center and half-width, both trainable
            self.b_val = torch.nn.parameter.Parameter(torch.randn((out_features,)).type(torch.float32), requires_grad=True)
            self.b_std = torch.nn.parameter.Parameter(torch.ones_like(self.b_val) * torch.Tensor([np.sqrt(6 / in_features)]).type(torch.float32), requires_grad=True)

    def forward(self, x: torch.Tensor):
        # Sample weights uniformly in [val - std, val + std]; the maximum() keeps the interval from collapsing
        weight = torch.distributions.Uniform(low=self.w_val - torch.maximum(self.w_std, self.eps), high=self.w_val + torch.maximum(self.w_std, self.eps)).rsample()
        if hasattr(self, 'b_val'):
            bias = torch.distributions.Uniform(low=self.b_val - torch.maximum(self.b_std, self.eps), high=self.b_val + torch.maximum(self.b_std, self.eps)).rsample()
            return x @ weight.T + bias
        else:
            return x @ weight.T

# Simple sine dataset
X = torch.linspace(-math.pi, math.pi, 100).view(-1, 1)
Y = torch.sin(torch.linspace(-math.pi, math.pi, 100))
model = torch.nn.Sequential(ProbLinear(1, 35), torch.nn.Tanh(), ProbLinear(35, 1), torch.nn.Tanh())

optim = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_values = []
max_e = 10000
for e in range(max_e):
    optim.zero_grad()
    # Draw 100 stochastic forward passes and fit the mean prediction to the target
    preds = torch.vstack([model(X).view(-1) for i in range(100)])
    loss = ((Y - preds.mean(dim=0).view(-1)) ** 2).mean()
    loss.backward()
    loss_values.append(loss.item())
    print(f"{(e / max_e * 100.0):0.2f}% : {loss.item():0.5f}", end=f"{' ' * 30}\r")

    optim.step()

final loss: 0.001245

Looking at the results, the mean prediction does come close to the target, and you can see that there is wide variation among the individual sets of predictions (100 samples).

preds = torch.vstack([model(X).view(-1) for i in range(100)])
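
If you want to visualize that spread yourself, here is a minimal plotting sketch (it assumes matplotlib is available; the snippet is illustrative and not part of the original post):

import matplotlib.pyplot as plt  # assumed available

with torch.no_grad():
    preds = torch.vstack([model(X).view(-1) for i in range(100)])

plt.plot(X.view(-1), Y, label="target")
plt.plot(X.view(-1), preds.mean(dim=0), label="mean prediction")
plt.fill_between(X.view(-1), preds.min(dim=0).values, preds.max(dim=0).values,
                 alpha=0.3, label="spread over 100 samples")
plt.legend()
plt.show()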
