I’m wondering whether it is possible to train the parameters used to generate a distribution, for example the loc and scale values of torch.distributions.Normal(). I have written my own version of torch.nn.Linear() where the hope was to train a mean and standard deviation that are used to generate the weight values. However, the distribution parameters do not seem to be tracked by autograd, so they never receive gradients. Any suggestions on how I might accomplish this?
For reference, here is my code for a probabilistic linear layer:
import torch


class ProbLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, n_scale=0.01, **kwargs):
        super().__init__()
        # Trainable mean and standard deviation used to generate the weights
        loc_val = torch.zeros((out_features, in_features))
        scale_val = torch.ones((out_features, in_features)) * n_scale
        self.w_loc = torch.nn.parameter.Parameter(loc_val, requires_grad=True)
        self.w_scale = torch.nn.parameter.Parameter(scale_val, requires_grad=True)
        if bias:
            loc_val = torch.zeros(out_features)
            scale_val = torch.ones(out_features) * n_scale
            self.b_loc = torch.nn.parameter.Parameter(loc_val, requires_grad=True)
            self.b_scale = torch.nn.parameter.Parameter(scale_val, requires_grad=True)

    def forward(self, x: torch.Tensor):
        # Draw the weights (and bias) for this forward pass; ideally the gradient
        # would flow back to w_loc / w_scale (and b_loc / b_scale)
        weight = torch.distributions.Normal(loc=self.w_loc, scale=self.w_scale).sample().requires_grad_(True)
        if hasattr(self, 'b_loc'):
            bias = torch.distributions.Normal(loc=self.b_loc, scale=self.b_scale).sample().requires_grad_(True)
            return x @ weight.T + bias
        else:
            return x @ weight.T
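To illustrate what I mean by the parameters not being tracked, here is a quick check (the feature and batch sizes are arbitrary, just for illustration): after a forward and backward pass, the sampled weight tensor gets a gradient, but w_loc and w_scale stay at None.

# Quick check of the behaviour (arbitrary example sizes)
layer = ProbLinear(in_features=4, out_features=3)
x = torch.randn(2, 4)

loss = layer(x).sum()
loss.backward()

print(layer.w_loc.grad)    # None -- no gradient reaches the distribution parameters
print(layer.w_scale.grad)  # None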