In xavier_normal_, we can manually set the gain.
But in kaiming_normal_, the gain is calculated by gain = calculate_gain(nonlinearity, a) inside the function.
Is there any way we can also manually set the gain for kaiming_normal_? If not, may I know the reason behind this design?
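For reference, the difference between the two calls looks roughly like this. The tensor shapes and the tanh / leaky_relu choices are arbitrary and only for illustration:

```python
import math

import torch
from torch.nn import init

# xavier_normal_ accepts the gain directly as an argument.
w_xavier = torch.empty(256, 128)
init.xavier_normal_(w_xavier, gain=init.calculate_gain("tanh"))

# kaiming_normal_ has no gain argument; the gain is derived internally
# from `nonlinearity` and `a` (the negative slope for leaky_relu).
w_kaiming = torch.empty(256, 128)
init.kaiming_normal_(w_kaiming, a=0.2, mode="fan_in", nonlinearity="leaky_relu")

# The gain that kaiming_normal_ computes for the call above.
implied_gain = init.calculate_gain("leaky_relu", 0.2)
```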
Many thanks!
I guess the current interface might be more convenient for standard use cases.
If you want to manually define the gain, you could create your own custom function e.g. as:
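import math
import torch

# _calculate_correct_fan is a private helper and may change between releases
fan = torch.nn.init._calculate_correct_fan(tensor, mode='fan_in')
gain = ...  # your gain calculation
std = gain / math.sqrt(fan)
with torch.no_grad():
    tensor.normal_(0, std)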
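Wrapped into a reusable function, this could look roughly like the sketch below. The name kaiming_normal_with_gain_ is hypothetical and not part of torch.nn.init, and it relies on the private helper _calculate_correct_fan, so treat it as an assumption-laden example rather than an official API:

```python
import math

import torch


def kaiming_normal_with_gain_(tensor, gain, mode="fan_in"):
    """Fill `tensor` in place from N(0, std^2) with std = gain / sqrt(fan).

    Hypothetical helper: same scaling as kaiming_normal_, but the gain is
    supplied by the caller instead of being derived from a nonlinearity.
    """
    # Private PyTorch helper; returns fan_in or fan_out depending on `mode`.
    fan = torch.nn.init._calculate_correct_fan(tensor, mode)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)


# Usage: a gain of sqrt(2) matches what calculate_gain("relu") would give.
weight = torch.empty(256, 128)
kaiming_normal_with_gain_(weight, gain=math.sqrt(2.0))
```

Passing the gain explicitly keeps the fan computation identical to the built-in initializer while leaving the scaling factor fully under your control.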
Thanks! I will have a try.