In Pyro, a constrained parameter is declared like this:
from pyro.distributions import constraints
weights = pyro.param('weights', torch.ones(K) / K, constraint=constraints.simplex)
What's the equivalent when using a plain torch.nn.Parameter?
IIRC, Pyro's constraints module just re-exports the constraints from PyTorch (torch.distributions.constraints).
Check the torch.distributions.constraint_registry docs; there is a code example there. Alternatively, you can transform the parameter manually, e.g. scale = F.softplus(self.raw_scale), where self.raw_scale is an unconstrained nn.Parameter.
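A minimal sketch of both approaches with plain nn.Parameters, assuming a recent PyTorch. The module name SimplexWeights and the attribute names are hypothetical. The idea: store an unconstrained raw tensor as the nn.Parameter and map it into the constrained space on every access. transform_to(constraints.simplex) comes from the constraint registry (in the versions I know it returns a softmax transform, so K unconstrained values map to K weights). The scale uses the manual softplus route.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import constraints, transform_to


class SimplexWeights(nn.Module):
    """Hypothetical module: unconstrained nn.Parameters exposed through
    constrained views, roughly what pyro.param(..., constraint=...) does."""

    def __init__(self, K: int):
        super().__init__()
        # Registry lookup: a transform from unconstrained reals onto the simplex.
        self.to_simplex = transform_to(constraints.simplex)
        # Start at the uniform point of the simplex and store its unconstrained
        # preimage, similar to how Pyro handles the initial constrained value.
        init = torch.full((K,), 1.0 / K)
        self.raw_weights = nn.Parameter(self.to_simplex.inv(init))
        # Manual alternative: any real raw_scale maps to a positive scale via softplus.
        self.raw_scale = nn.Parameter(torch.zeros(()))

    @property
    def weights(self) -> torch.Tensor:
        # Recomputed on each access, so gradients flow back to raw_weights.
        return self.to_simplex(self.raw_weights)

    @property
    def scale(self) -> torch.Tensor:
        return F.softplus(self.raw_scale)
```

The optimizer only ever sees raw_weights and raw_scale, so ordinary unconstrained updates (e.g. torch.optim.Adam over module.parameters()) can never push weights off the simplex or scale below zero. If you need a true bijection instead, biject_to(constraints.simplex) gives a stick-breaking transform, which uses K - 1 unconstrained values for K weights.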