I found this code:
for m in self.modules():
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        import scipy.stats as stats
        stddev = m.stddev if hasattr(m, 'stddev') else 0.1
        X = stats.truncnorm(-2, 2, scale=stddev)
        values = torch.Tensor(X.rvs(m.weight.data.numel()))
        values = values.view(m.weight.data.size())
        m.weight.data.copy_(values)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
Can anyone explain what happens in each branch after the isinstance check matches nn.Conv2d, nn.Linear, or nn.BatchNorm2d?
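
To make the question concrete, here is a minimal standalone sketch of what the truncated normal draw in the Conv2d/Linear branch appears to do. It assumes that stats.truncnorm(-2, 2, scale=stddev) takes its bounds in units of standard deviations, so every sample falls within plus or minus 2 * stddev. The helper name truncated_normal_samples and the layer shape are hypothetical, and rejection sampling is swapped in for scipy so the sketch runs with only the standard library.

```python
import random


def truncated_normal_samples(n, stddev=0.1, bound=2.0, seed=None):
    """Draw n samples from N(0, stddev^2), rejecting any outside +/- bound * stddev.

    Hypothetical helper: a stand-in for stats.truncnorm(-2, 2, scale=stddev).rvs(n).
    """
    rng = random.Random(seed)
    samples = []
    while len(samples) < n:
        x = rng.gauss(0.0, stddev)
        # Keep the draw only if it lies inside the truncation window.
        if abs(x) <= bound * stddev:
            samples.append(x)
    return samples


# Hypothetical weight shape: a 3x3 conv with 4 input and 8 output channels.
shape = (8, 4, 3, 3)
numel = 1
for dim in shape:
    numel *= dim  # same role as m.weight.data.numel() in the snippet

# One value per weight element; the snippet then reshapes these with .view(...)
# and copies them into m.weight.
values = truncated_normal_samples(numel, stddev=0.1, seed=0)
print(len(values), min(values), max(values))
```

If that reading is right, the first branch fills each Conv2d or Linear weight with small random values clipped to two standard deviations, while the BatchNorm2d branch sets the scale to 1 and the shift to 0, so the layer initially applies only the normalization itself.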