Hello,
After adding a parametrization module with:

```python
torch.nn.utils.parametrize.register_parametrization(
    self.conv, "weight", _SumOne()
)
```

when does the `forward()` function of `_SumOne` get called? The `forward` function makes the weights sum to 1:
```python
class _SumOne(torch.nn.Module):
    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # Divide by the total so that the entries sum to 1
        return torch.divide(weight, weight.sum())
```
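A minimal sketch for checking what registration does to the module. It assumes a standalone `nn.Conv2d(1, 1, 3)` in place of `self.conv`; the sizes, the positive initialization (which only keeps `weight.sum()` away from zero) and the variable names are illustrative choices, not from the post. As far as I understand the `parametrize` docs, the unconstrained tensor is moved to `conv.parametrizations.weight.original`, which is what `conv.parameters()` (and therefore the optimizer) sees, while `conv.weight` becomes a computed attribute:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _SumOne(nn.Module):
    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # Divide by the total so that the entries sum to 1
        return torch.divide(weight, weight.sum())


# Standalone Conv2d standing in for self.conv (sizes chosen arbitrarily)
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
with torch.no_grad():
    # Positive init keeps weight.sum() well away from zero
    conv.weight.uniform_(0.1, 1.0)

parametrize.register_parametrization(conv, "weight", _SumOne())

# The unconstrained tensor now lives here; this is what the optimizer updates
original = conv.parametrizations.weight.original
param_names = [name for name, _ in conv.named_parameters()]
print(param_names)

# conv.weight is no longer a stored Parameter; it is recomputed from `original`
print(conv.weight.sum())
```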
Is it called whenever the weights are accessed (in the forward and backward passes), or only when the weights are modified (in the optimizer step after the backward pass)?
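One way to answer this empirically is to count calls. The sketch below adds a hypothetical `calls` counter to `_SumOne` and uses a standalone `nn.Conv2d` and `torch.optim.SGD` as stand-ins for the real model and optimizer. It records how many times `forward()` runs on a plain attribute read, a forward pass, `backward()`, and `optimizer.step()`:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _SumOne(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.calls = 0  # hypothetical counter, added only for this experiment

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return torch.divide(weight, weight.sum())


torch.manual_seed(0)
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
with torch.no_grad():
    conv.weight.uniform_(0.1, 1.0)  # keep weight.sum() away from zero
sum_one = _SumOne()
parametrize.register_parametrization(conv, "weight", sum_one)
optimizer = torch.optim.SGD(conv.parameters(), lr=0.01)
x = torch.randn(1, 1, 8, 8)

counts = {}

# Registration may itself call forward() once as a consistency check, so reset
sum_one.calls = 0
_ = conv.weight
counts["attribute access"] = sum_one.calls

sum_one.calls = 0
out = conv(x)  # Conv2d.forward reads self.weight
counts["forward pass"] = sum_one.calls

sum_one.calls = 0
out.sum().backward()
counts["backward pass"] = sum_one.calls

sum_one.calls = 0
optimizer.step()  # updates conv.parametrizations.weight.original in place
counts["optimizer step"] = sum_one.calls

print(counts)
```

If I read the `parametrize` docs correctly, `forward()` runs once per read of `conv.weight` (so once per `conv(x)` call) and not at all during `backward()`, which reuses the autograd graph recorded in the forward pass, or during `optimizer.step()`, which only updates `original` in place.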
Note: `self.conv` is a `torch.nn.Conv2d`.
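Because `nn.Conv2d` reads `self.weight` on every call, a conv that is invoked several times per step recomputes the parametrization each time. The sketch below, again with a hypothetical call counter and a standalone conv, compares that with the `torch.nn.utils.parametrize.cached()` context manager, which (per its documentation) computes a parametrized tensor once and reuses it inside the `with` block:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _SumOne(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.calls = 0  # hypothetical counter for illustration only

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return torch.divide(weight, weight.sum())


conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
with torch.no_grad():
    conv.weight.uniform_(0.1, 1.0)  # keep weight.sum() away from zero
sum_one = _SumOne()
parametrize.register_parametrization(conv, "weight", sum_one)
x = torch.randn(1, 1, 8, 8)

# Without caching: each conv(x) reads conv.weight, so forward() runs every time
sum_one.calls = 0
for _ in range(3):
    conv(x)
uncached_calls = sum_one.calls

# With caching: the parametrized weight is computed once and reused in the block
sum_one.calls = 0
with parametrize.cached():
    for _ in range(3):
        conv(x)
cached_calls = sum_one.calls

print(uncached_calls, cached_calls)
```

Caching mostly pays off when the same parametrized weight is read many times per step. The cached value is only recomputed when the `with` block is entered again, so the block should not span an `optimizer.step()`.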
Thanks.