I’d like to make a native torch nn.Linear layer probabilistic by:
- Adding weight and bias distribution objects to the layer using the register_module method. For example:
layer.register_module('weight_dist', some_weight_dist_obj)
layer.register_module('bias_dist', some_bias_dist_obj)
where some_weight_dist_obj and some_bias_dist_obj are objects of the class GaussianVariable:
class GaussianVariable(AbstractVariable):
    ...
    def sample(self) -> Tensor:
        epsilon = torch.randn_like(self.sigma)
        # self.mu and self.sigma are trainable nn.Parameter objects
        return self.mu + self.sigma * epsilon
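To make this concrete, here is a minimal self-contained stub of such a class (the AbstractVariable base and the initialization scheme are placeholders of my own; only sample() matters here). The reparameterization trick keeps the sample differentiable with respect to mu and sigma:

```python
import torch
from torch import nn, Tensor

class GaussianVariable(nn.Module):
    """Diagonal Gaussian with trainable mean and scale (minimal stub)."""

    def __init__(self, shape):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        # stored directly for simplicity; a softplus-of-rho parameterization
        # would keep sigma positive during training
        self.sigma = nn.Parameter(torch.full(shape, 0.1))

    def sample(self) -> Tensor:
        # reparameterization trick: the sample is a differentiable
        # function of mu and sigma
        epsilon = torch.randn_like(self.sigma)
        return self.mu + self.sigma * epsilon
```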
- Replacing the old forward with a new one:
def new_forward(self, input: torch.Tensor) -> torch.Tensor:
    sampled_weight = self.weight_dist.sample()
    sampled_bias = self.bias_dist.sample()
    return F.linear(input, sampled_weight, sampled_bias)
...
layer.forward = new_forward.__get__(layer, nn.Module)
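For reference, a minimal runnable version of this working setup (GaussianVariable is reduced to a stub of the class above; the layer sizes are arbitrary). Gradients flow back to mu and sigma because the sampled tensors are passed to F.linear directly:

```python
import torch
import torch.nn.functional as F
from torch import nn, Tensor

class GaussianVariable(nn.Module):
    """Minimal stub of the GaussianVariable class above."""
    def __init__(self, shape):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        self.sigma = nn.Parameter(torch.full(shape, 0.1))

    def sample(self) -> Tensor:
        return self.mu + self.sigma * torch.randn_like(self.sigma)

def new_forward(self, input: Tensor) -> Tensor:
    # use the sampled tensors directly so autograd reaches mu/sigma
    return F.linear(input, self.weight_dist.sample(), self.bias_dist.sample())

layer = nn.Linear(4, 3)
layer.register_module('weight_dist', GaussianVariable((3, 4)))
layer.register_module('bias_dist', GaussianVariable((3,)))
layer.forward = new_forward.__get__(layer, nn.Module)

out = layer(torch.randn(2, 4))
out.sum().backward()
```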
In this setting I don’t face any problems during network training, i.e. self.mu and self.sigma are being updated.
However, I don’t want to reimplement the forward function as I showed in step 2 (for nn.Linear it is OK, but later I also want to use other, more complex layers). Rather, I would like to reuse nn.Linear.forward. For that, I tried many options, but none of them worked:
- I did not replace forward with new_forward, but used hooks instead:
def forward_hook(module, inputs):
    nn.utils.vector_to_parameters(module.weight_dist.sample().flatten(), module.weight)
    nn.utils.vector_to_parameters(module.bias_dist.sample().flatten(), module.bias)
    return inputs
...
layer.register_forward_pre_hook(forward_hook)
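Here is a self-contained version of this hook attempt (GaussianVariable again reduced to a stub; I wrap the parameter in a list, since vector_to_parameters expects an iterable of parameters). As far as I can tell, the sampled values do get copied into module.weight, but the copy goes through the parameter’s .data and is invisible to autograd, so after backward the gradients land on the nn.Linear parameters and never reach mu/sigma:

```python
import torch
from torch import nn, Tensor

class GaussianVariable(nn.Module):
    """Minimal stub of the GaussianVariable class above."""
    def __init__(self, shape):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        self.sigma = nn.Parameter(torch.full(shape, 0.1))

    def sample(self) -> Tensor:
        return self.mu + self.sigma * torch.randn_like(self.sigma)

def forward_hook(module, inputs):
    # note the lists: vector_to_parameters iterates over its second argument
    nn.utils.vector_to_parameters(module.weight_dist.sample().flatten(), [module.weight])
    nn.utils.vector_to_parameters(module.bias_dist.sample().flatten(), [module.bias])
    return inputs

layer = nn.Linear(4, 3)
layer.register_module('weight_dist', GaussianVariable((3, 4)))
layer.register_module('bias_dist', GaussianVariable((3,)))
layer.register_forward_pre_hook(forward_hook)

out = layer(torch.randn(2, 4))
out.sum().backward()
# layer.weight.grad is populated, but layer.weight_dist.mu.grad stays None:
# the copy into .data is not tracked by autograd
```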
- Used the following new_forward:
def new_forward(self, input: torch.Tensor) -> torch.Tensor:
    nn.utils.vector_to_parameters(self.weight_dist.sample().flatten(), self.weight)
    nn.utils.vector_to_parameters(self.bias_dist.sample().flatten(), self.bias)
    return nn.Linear.forward(self, input)
- Used different assignments:
nn.utils.vector_to_parameters(self.weight_dist.sample().flatten(), self.weight)
or
self.weight.data = self.weight_dist.sample()
or
self.weight = nn.Parameter(self.weight_dist.sample())
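The .data variant shows the same symptom in isolation (same minimal GaussianVariable stub as above). The sampled values reach the weight, but assigning through .data detaches them from the graph, so mu and sigma receive no gradients:

```python
import torch
from torch import nn, Tensor

class GaussianVariable(nn.Module):
    """Minimal stub of the GaussianVariable class above."""
    def __init__(self, shape):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        self.sigma = nn.Parameter(torch.full(shape, 0.1))

    def sample(self) -> Tensor:
        return self.mu + self.sigma * torch.randn_like(self.sigma)

layer = nn.Linear(4, 3)
layer.register_module('weight_dist', GaussianVariable((3, 4)))

# assignment via .data is not tracked by autograd
layer.weight.data = layer.weight_dist.sample()
out = layer(torch.randn(2, 4))
out.sum().backward()
# layer.weight.grad is filled; layer.weight_dist.mu.grad remains None
```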
Could you help me with that? How do I properly reuse the original forward method in this setting?