Yes, the times are identical with and without pretrained weights. I’m using the latest PyPI versions of torch (1.4.0) and torchvision (0.5.0), and it happens on multiple machines.
Specifically, the slow part is drawing the truncated normal samples via X.rvs(m.weight.numel()). Could a change to scipy.stats have slowed this down? Also, is it possible to draw the same truncated normal with torch’s built-in tensor operations?
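On the second question: yes, a truncated normal can be drawn with torch’s own tensor ops via inverse-CDF sampling, avoiding scipy entirely. A minimal sketch, assuming a recent PyTorch (the function name truncated_normal_ is mine; newer releases also ship an equivalent nn.init.trunc_normal_, which isn’t in 1.4):

```python
import math
import torch

def truncated_normal_(tensor, a=-2.0, b=2.0, std=0.1):
    # Inverse-CDF sampling: draw uniforms inside [Phi(a), Phi(b)],
    # then map them back through the normal quantile function.
    lo = 0.5 * (1 + math.erf(a / math.sqrt(2)))
    hi = 0.5 * (1 + math.erf(b / math.sqrt(2)))
    with torch.no_grad():
        tensor.uniform_(lo, hi)            # u ~ U(Phi(a), Phi(b))
        tensor.mul_(2).sub_(1)             # map to erfinv's (-1, 1) domain
        tensor.erfinv_()                   # standard normal quantiles
        tensor.mul_(math.sqrt(2) * std)    # rescale to N(0, std^2) truncated
    return tensor

w = truncated_normal_(torch.empty(1000), std=0.1)
```

All samples land in [a*std, b*std], matching scipy’s truncnorm(-2, 2, scale=std), and the whole thing runs as vectorized tensor ops.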
If I change the weight initialization to something like
    import scipy.stats as stats

    for m in self.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            stddev = m.stddev if hasattr(m, 'stddev') else 0.1
            X = stats.truncnorm(-2, 2, scale=stddev)
            # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
            # values = values.view(m.weight.size())
            foo = X.rvs(m.weight.numel())  # sample, but discard the result
            values = torch.zeros_like(m.weight)  # zeros_like takes a tensor, not a size
            with torch.no_grad():
                m.weight.copy_(values)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
Then I still see the slowdown, which disappears if I remove the X.rvs() call. A single call to X.rvs() doesn’t take a particularly long time, but the loop iterates over ~300 layers.
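To confirm where the time goes, here is a quick micro-benchmark of the rvs() calls alone; the layer and parameter counts are rough stand-ins for the real model, not its actual sizes:

```python
import time
import scipy.stats as stats

X = stats.truncnorm(-2, 2, scale=0.1)
n_layers, n_params = 300, 10_000  # illustrative stand-ins only

start = time.perf_counter()
for _ in range(n_layers):
    X.rvs(n_params)
elapsed = time.perf_counter() - start
print(f"{n_layers} rvs() calls took {elapsed:.2f}s")
```

Running this under scipy 1.3.x and 1.4.x on the same machine should make any per-call regression obvious.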
That scipy.stats regression does look like the same issue we’re seeing here. In the meantime, I’ve simply removed the weight initialization, since I always use the pretrained weights anyway.
Downgrading scipy from 1.4 to 1.3.3 fixed the issue for me, and the loading time is much faster now. I just had to run pip install --upgrade scipy==1.3.3