Constrained parameters in torch

The way to specify a constrained parameter in Pyro:

weights = pyro.param('weights', torch.randn(K), constraint=constraints.simplex)

  • What's the equivalent way to do this with torch.nn.Parameter?

IIRC Pyro just imports its constraints from PyTorch (torch.distributions.constraints).
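Because the constraint objects live in torch.distributions.constraints, one way to mirror constraint=constraints.simplex with a plain nn.Parameter is a minimal sketch like the one below. It assumes torch.distributions.transform_to, which looks up a transform from unconstrained reals onto a given constraint: the module stores an unconstrained parameter and maps it onto the simplex every time the value is read. The module name SimplexWeights, the uniform initialisation and the toy target are hypothetical, for illustration only.

```python
import torch
import torch.nn as nn
from torch.distributions import constraints, transform_to


class SimplexWeights(nn.Module):
    """Hypothetical module: K weights constrained to lie on the simplex."""

    def __init__(self, K: int):
        super().__init__()
        # Transform that maps unconstrained reals onto the simplex.
        self._transform = transform_to(constraints.simplex)
        init = torch.full((K,), 1.0 / K)  # start at the uniform distribution
        # Store the *unconstrained* value; the optimizer updates it freely.
        self.unconstrained_weights = nn.Parameter(self._transform.inv(init))

    @property
    def weights(self) -> torch.Tensor:
        # Map back onto the simplex each time the value is read.
        return self._transform(self.unconstrained_weights)


# Usage: fit the weights to a toy target distribution.
model = SimplexWeights(K=3)
opt = torch.optim.Adam(model.parameters(), lr=0.05)
target = torch.tensor([0.7, 0.2, 0.1])
for _ in range(500):
    opt.zero_grad()
    loss = ((model.weights - target) ** 2).sum()
    loss.backward()
    opt.step()
print(model.weights.detach())
```

The optimizer never sees the constraint: it updates unconstrained_weights freely, and the transform guarantees the values you read stay non-negative and sum to 1.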

Check the torch.distributions.constraint_registry docs; there is a code example there. Alternatively, you can transform the parameters manually, e.g. scale = F.softplus(self.raw_scale), where raw_scale is an unconstrained nn.Parameter.
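The manual softplus route can be sketched as below, assuming a single positive scale. The class name PositiveScale and the init_scale argument are hypothetical. The raw parameter is initialised with the inverse of softplus, log(exp(y) - 1), so the scale starts at the requested value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositiveScale(nn.Module):
    """Hypothetical module holding one strictly positive scale parameter."""

    def __init__(self, init_scale: float = 1.0):
        super().__init__()
        init = torch.tensor(init_scale)
        # Inverse softplus, so softplus(raw_scale) == init_scale at the start.
        self.raw_scale = nn.Parameter(torch.log(torch.expm1(init)))

    @property
    def scale(self) -> torch.Tensor:
        # softplus(x) = log(1 + exp(x)) is positive for every real x.
        return F.softplus(self.raw_scale)


# Usage: push the scale toward zero. raw_scale may go negative; scale cannot.
m = PositiveScale(init_scale=2.0)
opt = torch.optim.SGD(m.parameters(), lr=1.0)
for _ in range(50):
    opt.zero_grad()
    m.scale.backward()
    opt.step()
print(m.scale.item())
```

Softplus is a common alternative to exp here, since it grows linearly rather than exponentially for large raw values.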