I would like to obtain deterministic/reproducible weight initialization when working with, say, Conv1d, while limiting the deterministic state to the [default] weight initialization only and retaining the random state for the rest of the torch modules. The following (non-functional) code expresses the idea I am trying to achieve:
```python
gen1 = torch.Generator()
gen1.manual_seed(123)
conv1 = nn.Conv1d(16, 33, 3, stride=2)
nn.init.kaiming_uniform_(conv1.weight, a=0, mode='fan_in',
                         nonlinearity='leaky_relu', generator=gen1)

gen2 = torch.Generator()
gen2.manual_seed(123)
conv2 = nn.Conv1d(16, 33, 3, stride=2)
nn.init.kaiming_uniform_(conv2.weight, a=0, mode='fan_in',
                         nonlinearity='leaky_relu', generator=gen2)

assert torch.equal(conv1.weight, conv2.weight)
```
However, `kaiming_uniform_` does not accept a `generator` argument. Is there a way to do this? Seeding a dedicated generator would ensure that the sampling from the distribution is the same on every run, without touching the global RNG.
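One workaround I have been considering is a sketch built on `torch.random.fork_rng`, which snapshots the global RNG state and restores it on exit; seeding inside the context then affects only the initialization, not the rest of the program. The helper name `make_conv` is my own, not from any library:

```python
import torch
import torch.nn as nn

def make_conv(seed):
    # fork_rng saves the global RNG state and restores it when the
    # block exits, so manual_seed below is scoped to initialization.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        conv = nn.Conv1d(16, 33, 3, stride=2)
        nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in',
                                 nonlinearity='leaky_relu')
    return conv

conv1 = make_conv(123)
conv2 = make_conv(123)
assert torch.equal(conv1.weight, conv2.weight)
```

This keeps the initialization reproducible while the global random state outside the `with` block is left exactly as it was, so subsequent random draws (dropout, data shuffling, etc.) remain unaffected.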