# Probabilistic layer, learning distribution parameters?

I’m wondering if it is possible to train parameters used to generate a distribution. For example, the `loc` and `scale` values for `torch.distributions.Normal()`.

I have created my own version of `torch.nn.Linear()` where my hope was to train a mean and standard deviation to be used for generating the weight values. However, it seems that the parameters for the distribution are not tracked as part of the gradient. Any suggestions on how I might accomplish this?

For reference, here is my code for a probabilistic linear layer:

```python
import torch

class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, n_scale=0.01, **kwargs):
        super().__init__()
        # distribution parameters for the weights
        loc_val = torch.zeros((out_features, in_features))
        scale_val = torch.ones((out_features, in_features)) * n_scale
        self.w_loc = torch.nn.Parameter(loc_val)
        self.w_scale = torch.nn.Parameter(scale_val)

        if bias:
            # distribution parameters for the bias
            self.b_loc = torch.nn.Parameter(torch.zeros(out_features))
            self.b_scale = torch.nn.Parameter(torch.ones(out_features) * n_scale)

    def forward(self, x: torch.Tensor):
        weight = torch.distributions.Normal(loc=self.w_loc, scale=self.w_scale).sample().requires_grad_(True)
        if hasattr(self, 'b_loc'):
            bias = torch.distributions.Normal(loc=self.b_loc, scale=self.b_scale).sample().requires_grad_(True)
            return x @ weight.T + bias
        else:
            return x @ weight.T
```

> I’m wondering if it is possible to train parameters used to generate a distribution.

You can do that, but you’ll need to tweak a few things in your implementation:

- You’re calling `requires_grad_(True)` on the sample, which is not needed because you want to update the `loc` and `scale`, not the sample. You can remove that.

> it seems that the parameters for the distribution are not tracked as part of the gradient.

- That’s because no graph is being created: if you check the code for `sample()`, the whole call is wrapped in `torch.no_grad()`. The fix is to use `rsample()` instead, which draws the sample via the reparameterization trick so it stays differentiable with respect to `loc` and `scale` (see the quick check below).
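
For example, a quick standalone check shows that `sample()` returns a detached tensor while `rsample()` keeps a `grad_fn`:

```python
loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3, requires_grad=True)
dist = torch.distributions.Normal(loc, scale)

print(dist.sample().grad_fn)   # None -> sample() runs under torch.no_grad()
print(dist.rsample().grad_fn)  # an autograd node -> rsample() stays differentiable
```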

The rest seems good, so your code becomes:

```python
import torch

class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, n_scale=0.01, **kwargs):
        super().__init__()
        loc_val = torch.zeros((out_features, in_features))
        scale_val = torch.ones((out_features, in_features)) * n_scale
        self.w_loc = torch.nn.Parameter(loc_val)
        self.w_scale = torch.nn.Parameter(scale_val)

        if bias:
            self.b_loc = torch.nn.Parameter(torch.zeros(out_features))
            self.b_scale = torch.nn.Parameter(torch.ones(out_features) * n_scale)

    def forward(self, x: torch.Tensor):
        # rsample() keeps the sampling step in the graph, so loc and scale receive gradients
        weight = torch.distributions.Normal(loc=self.w_loc, scale=self.w_scale).rsample()
        if hasattr(self, 'b_loc'):
            bias = torch.distributions.Normal(loc=self.b_loc, scale=self.b_scale).rsample()
            return x @ weight.T + bias
        else:
            return x @ weight.T
```
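
As a quick sanity check (with arbitrary example sizes), you can confirm that gradients now reach the distribution parameters:

```python
layer = ProbLinear(4, 2)
out = layer(torch.randn(8, 4))
out.sum().backward()

print(layer.w_loc.grad.shape)    # torch.Size([2, 4])
print(layer.w_scale.grad.shape)  # torch.Size([2, 4])
```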

Hope it helps!


Fantastic! Thank you. It does seem to work. Now to see if I can actually get it to optimize.


For anyone who finds this and wishes to explore further, here is the very simple test case I used to verify that it actually works (though it may be very inefficient).

Using a slightly modified `ProbLinear()` as mentioned by @krypticmouse and a simple sine dataset…

```python
import numpy as np
import torch

class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__()
        self.eps = torch.tensor([1e-6])
        self.w_val = torch.nn.Parameter(torch.randn((out_features, in_features)).type(torch.float32), requires_grad=True)
        self.w_std = torch.nn.Parameter(torch.ones_like(self.w_val) * torch.Tensor([np.sqrt(6 / in_features)]).type(torch.float32), requires_grad=True)

        if bias:
            self.b_val = torch.nn.Parameter(torch.randn((out_features)).type(torch.float32), requires_grad=True)
            self.b_std = torch.nn.Parameter(torch.ones_like(self.b_val) * torch.Tensor([np.sqrt(6 / in_features)]).type(torch.float32), requires_grad=True)

    def forward(self, x: torch.Tensor):
        # sample weights (and bias) uniformly around the learned values; the half-width is clamped to at least eps
        weight = torch.distributions.Uniform(low=self.w_val - torch.maximum(self.w_std, self.eps),
                                             high=self.w_val + torch.maximum(self.w_std, self.eps)).rsample()
        if hasattr(self, 'b_val'):
            bias = torch.distributions.Uniform(low=self.b_val - torch.maximum(self.b_std, self.eps),
                                               high=self.b_val + torch.maximum(self.b_std, self.eps)).rsample()
            return x @ weight.T + bias
        else:
            return x @ weight.T
```
```python
import math

X = torch.linspace(-math.pi, math.pi, 100).view(-1, 1)
Y = torch.sin(torch.linspace(-math.pi, math.pi, 100))
model = torch.nn.Sequential(ProbLinear(1, 35), torch.nn.Tanh(), ProbLinear(35, 1), torch.nn.Tanh())
# optimizer not shown in the original snippet; Adam is an assumption
optim = torch.optim.Adam(model.parameters())

loss_values = []
max_e = 10000
for e in range(max_e):
    # average 100 stochastic forward passes per step
    preds = torch.vstack([model(X).view(-1) for i in range(100)])
    loss = ((Y - preds.mean(dim=0).view(-1)) ** 2).mean()

    optim.zero_grad()
    loss.backward()
    loss_values.append(loss.item())
    print(f"{(e / max_e * 100.0):0.2f}% : {loss.item():0.5f}", end=f"{' ' * 30}\r")

    optim.step()
```

final loss: `0.001245`

Looking at the results, the mean prediction does come close to the target, and you can see that there is wide variation across the individual predictions (100 samples).

```python
preds = torch.vstack([model(X).view(-1) for i in range(100)])
```
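
If you want to visualize that spread, a minimal matplotlib sketch (assuming `X`, `Y`, and `preds` from above) is one option:

```python
import matplotlib.pyplot as plt

x = X.view(-1)
plt.plot(x, preds.detach().T, color="gray", alpha=0.05)   # individual sampled predictions
plt.plot(x, preds.detach().mean(dim=0), label="mean prediction")
plt.plot(x, Y, label="target")
plt.legend()
plt.show()
```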
