Make torch linear layer probabilistic

I’d like to make a native torch nn.Linear layer probabilistic by

  1. Registering weight and bias distribution objects on the layer with the register_module method. For example:
```python
layer.register_module('weight_dist', some_weight_dist_obj)
layer.register_module('bias_dist', some_bias_dist_obj)
```

Here, some_weight_dist_obj and some_bias_dist_obj are instances of the class GaussianVariable:

```python
import torch
from torch import Tensor
class GaussianVariable(AbstractVariable):
    def sample(self) -> Tensor:
        epsilon = torch.randn_like(self.sigma)
        # self.mu (mean; name assumed) and self.sigma are trainable nn.Parameter objects
        return self.mu + self.sigma * epsilon
```
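  2. Replacing the old forward with a new one:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
def new_forward(self, input: torch.Tensor) -> torch.Tensor:
    sampled_weight = self.weight_dist.sample()
    sampled_bias = self.bias_dist.sample()
    return F.linear(input, sampled_weight, sampled_bias)
# Bind new_forward to this particular layer instance
layer.forward = new_forward.__get__(layer, nn.Module)
```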

With this setup, training works without any problems: the mean and self.sigma of each GaussianVariable are updated.
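A self-contained sketch of the working setup described above, useful as a baseline. Because AbstractVariable is not shown in the post, this GaussianVariable subclasses nn.Module directly; the attribute name `mu` for the mean, the shapes, and the initial values are assumptions. It checks that gradients reach the distribution parameters while the layer's original weight is never used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GaussianVariable(nn.Module):
    """Hypothetical stand-in for the post's GaussianVariable.

    Assumes the mean is stored as `mu` and the standard deviation as `sigma`,
    both trainable nn.Parameter objects (sigma is used directly, as in the post).
    """

    def __init__(self, shape: torch.Size) -> None:
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        self.sigma = nn.Parameter(torch.full(shape, 0.1))

    def sample(self) -> Tensor:
        # Reparameterization trick: mu + sigma * eps stays differentiable
        epsilon = torch.randn_like(self.sigma)
        return self.mu + self.sigma * epsilon


def new_forward(self: nn.Linear, input: Tensor) -> Tensor:
    sampled_weight = self.weight_dist.sample()
    sampled_bias = self.bias_dist.sample()
    return F.linear(input, sampled_weight, sampled_bias)


layer = nn.Linear(4, 3)
layer.register_module("weight_dist", GaussianVariable(layer.weight.shape))
layer.register_module("bias_dist", GaussianVariable(layer.bias.shape))
layer.forward = new_forward.__get__(layer, nn.Module)

out = layer(torch.randn(2, 4))
out.sum().backward()
print(out.shape)                              # torch.Size([2, 3])
print(layer.weight_dist.mu.grad is not None)  # True
print(layer.weight.grad)                      # None: the original weight is unused
```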

However, I don’t want to reimplement the forward function as in step 2. That is fine for nn.Linear, but I also want to apply this to other, more complex layers. Instead, I would like to reuse nn.Linear.forward. I tried several options, but none of them worked:

  1. Instead of replacing forward with new_forward, I used a forward pre-hook:
```python
def forward_hook(module, inputs):
    nn.utils.vector_to_parameters(module.weight_dist.sample().flatten(), module.weight)
    nn.utils.vector_to_parameters(module.bias_dist.sample().flatten(), module.bias)
    return inputs
```
  2. I used the following new_forward:
```python
def new_forward(self, input: torch.Tensor) -> torch.Tensor:
    nn.utils.vector_to_parameters(self.weight_dist.sample().flatten(), self.weight)
    nn.utils.vector_to_parameters(self.bias_dist.sample().flatten(), self.bias)
    return nn.Linear.forward(self, input)
```
  3. I tried different assignments:
```python
nn.utils.vector_to_parameters(self.weight_dist.sample().flatten(), self.weight)
# or
... = self.weight_dist.sample()
# or
self.weight = nn.Parameter(self.weight_dist.sample())
```
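A likely reason these attempts do not train the distributions: vector_to_parameters writes each slice through `param.data`, which moves values outside of autograd, so no gradient can flow back to the mean and sigma. Wrapping the sample in a new nn.Parameter has a similar effect, since the new parameter is a fresh leaf tensor with no link to the sample's graph. The minimal check below uses hypothetical stand-alone `mu` and `sigma` tensors in place of a GaussianVariable; note also that vector_to_parameters expects an iterable of tensors, so the single weight is wrapped in a list here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
# Hypothetical stand-ins for weight_dist's mean and standard deviation
mu = nn.Parameter(torch.zeros_like(layer.weight))
sigma = nn.Parameter(torch.full_like(layer.weight, 0.1))

sample = mu + sigma * torch.randn_like(sigma)  # differentiable w.r.t. mu and sigma
# vector_to_parameters expects an iterable of tensors, hence the list
nn.utils.vector_to_parameters(sample.flatten(), [layer.weight])

print(torch.equal(layer.weight, sample))  # True: layer.weight now holds the sampled values
layer(torch.randn(2, 4)).sum().backward()
print(mu.grad, sigma.grad)                # None None: the write bypassed autograd
print(layer.weight.grad is not None)      # True: the gradient stops at layer.weight
```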

Could you help me with this? How can I properly reuse the original forward method in this setting?

P.S. I don’t want to inherit from nn.Linear and create my own ProbLinear class or anything similar.
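One possible approach, sketched under assumptions rather than offered as a definitive answer: keep a forward pre-hook as in the first attempt, but instead of copying values into the existing nn.Parameter, remove that parameter and assign the freshly sampled tensor as a plain attribute before every forward call. nn.Linear.forward then reads the sample through `self.weight` / `self.bias` unchanged, and gradients reach the distribution parameters. This mirrors the pattern used by the older torch.nn.utils.weight_norm (delete the parameter, recompute it in a pre-hook). The GaussianVariable class, the `mu` name, and the `make_probabilistic` helper below are hypothetical, and the layer itself is not subclassed.

```python
import torch
import torch.nn as nn
from torch import Tensor


class GaussianVariable(nn.Module):
    # Hypothetical minimal version of the post's class; `mu` is an assumed name
    def __init__(self, shape: torch.Size) -> None:
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        self.sigma = nn.Parameter(torch.full(shape, 0.1))

    def sample(self) -> Tensor:
        return self.mu + self.sigma * torch.randn_like(self.sigma)


def make_probabilistic(layer: nn.Module, names=("weight", "bias")) -> nn.Module:
    """Hypothetical helper: replace the named parameters with samples.

    The layer's own forward (e.g. nn.Linear.forward) is left untouched.
    """
    for name in names:
        shape = getattr(layer, name).shape
        layer.register_module(f"{name}_dist", GaussianVariable(shape))
        # Remove the nn.Parameter so a plain (non-leaf) tensor can take its place
        delattr(layer, name)

    def sample_params(module: nn.Module, inputs):
        # Runs before every forward call; the samples stay in the autograd graph
        for name in names:
            setattr(module, name, getattr(module, f"{name}_dist").sample())
        return None  # None leaves the inputs unchanged

    layer.register_forward_pre_hook(sample_params)
    return layer


layer = make_probabilistic(nn.Linear(4, 3))
out = layer(torch.randn(2, 4))
out.sum().backward()
print(out.shape)                              # torch.Size([2, 3])
print(layer.weight_dist.mu.grad is not None)  # True
print(sorted(n for n, _ in layer.named_parameters()))
# ['bias_dist.mu', 'bias_dist.sigma', 'weight_dist.mu', 'weight_dist.sigma']
```

With this design, the sampled weight and bias are ordinary attributes, so they no longer appear in parameters() or state_dict(); only the distribution parameters do, which is what an optimizer should see. Code that expects `layer.weight` to be an nn.Parameter (for example calling reset_parameters() again) would need adjusting.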