Hello there,
I have run into a very similar issue when converting another model to TorchScript. At first I thought it froze on the second batch. IMO the problem is connected with BatchNorm; here is code that should explain it:
import torch
from tqdm import tqdm
import time
from typing import Type
from argparse import ArgumentParser
class Layer(torch.nn.Module):
    """Base module that records the channel counts it maps between.

    Subclasses add the actual computation; this base only stores the
    input/output feature sizes for later inspection.
    """

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        # Kept as plain attributes so callers can read the layer dimensions.
        self.num_input_features = num_input_features
        self.num_output_features = num_output_features
class WithBN(Layer):
    """BatchNorm2d followed by a bias-free 1x1 convolution."""

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__(num_input_features, num_output_features)
        self.bn = torch.nn.BatchNorm2d(num_input_features)
        self.conv = torch.nn.Conv2d(
            num_input_features,
            num_output_features,
            kernel_size=1,
            stride=1,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, then project to the new channel count.
        return self.conv(self.bn(x))
class WithoutBN(Layer):
    """Bias-free 1x1 convolution only — the BatchNorm-free control case."""

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__(num_input_features, num_output_features)
        self.conv = torch.nn.Conv2d(
            num_input_features,
            num_output_features,
            kernel_size=1,
            stride=1,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)
class DeepNetwork(torch.nn.Module):
    """Sequential stack of `num_layers` instances of `layer`.

    The channel count starts at `num_input_features` and increases by
    `growth` after every layer, DenseNet-style.
    """

    def __init__(self, num_input_features: int, num_layers: int, layer: Type[Layer], growth: int) -> None:
        super().__init__()
        in_features = num_input_features
        for idx in range(1, num_layers + 1):
            out_features = in_features + growth
            # Same naming scheme as torchvision's DenseNet blocks.
            self.add_module(f"denselayer{idx}", layer(in_features, out_features))
            in_features = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply children in registration order (insertion order is preserved).
        out = x
        for child in self.children():
            out = child(out)
        return out
def loop(model: torch.nn.Module, use_tqdm: bool = True) -> None:
    """Script `model` and time three no-grad forward passes.

    Reads the module-level tensor `glob_input` set up in __main__.

    :param model: module to script and benchmark
    :param use_tqdm: show a tqdm progress bar when True
    """
    # BUG FIX: torch.jit.script takes no example inputs — its second
    # positional parameter is the deprecated `optimize` flag, so the
    # original `torch.jit.script(model, glob_input)` passed a tensor there.
    # Scripting compiles from source alone.
    model = torch.jit.script(model)
    with torch.no_grad():  # culprit
        for i in tqdm(range(3), colour='green', leave=False, disable=not use_tqdm):
            start = time.time()
            model(glob_input)
            stop = time.time()
            # BUG FIX: the original spec ':8>.2f' used '8' as the fill char
            # with no width (a no-op); '>8.2f' right-aligns in 8 columns.
            print(f"It {i}th time:{stop - start:>8.2f}s")
if __name__ == '__main__':
    # NOTE: the original had `global device` / `global glob_input` here;
    # `global` at module level is a no-op (module-level assignments are
    # already global), so those statements were removed.
    parser = ArgumentParser()
    parser.add_argument('--gpu', action='store_true', help='use gpu')
    parser.add_argument('--growth', type=int, default=16, help='growth rate')
    parser.add_argument('--tqdm', action='store_true', help='use tqdm')
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")
    num_layers = 99
    growth = args.growth
    num_input_features = 64
    # Shared input read by loop() through the module-level name `glob_input`.
    glob_input = torch.rand(2, num_input_features, 224, 224).to(device)

    print("Model with BatchNorm")
    model = DeepNetwork(num_input_features, num_layers, WithBN, growth).to(device)
    model.eval()
    loop(model, use_tqdm=args.tqdm)

    print("Model without BatchNorm")
    model = DeepNetwork(num_input_features, num_layers, WithoutBN, growth).to(device)
    model.eval()
    loop(model, use_tqdm=args.tqdm)
I tested my hypothesis with networks of different widths, on both CPU and GPU. Here are the outputs, with the command used shown at the top of each run:
python3 freeze_example.py --growth=2
Model with BatchNorm
It 0th time:4.08s
It 1th time:12.35s
It 2th time:3.77s
Model without BatchNorm
It 0th time:2.96s
It 1th time:2.93s
It 2th time:2.88s
python3 freeze_example.py --growth=4
Model with BatchNorm
It 0th time:7.05s
It 1th time:15.53s
It 2th time:6.59s
Model without BatchNorm
It 0th time:5.95s
It 1th time:5.76s
It 2th time:5.53s
python3 freeze_example.py --growth=8
Model with BatchNorm
It 0th time:13.31s
It 1th time:21.57s
It 2th time:13.02s
Model without BatchNorm
It 0th time:11.77s
It 1th time:10.68s
It 2th time:10.75s
python3 freeze_example.py --growth=2 --gpu
Model with BatchNorm
It 0th time:2.29s
It 1th time:8.38s
It 2th time:0.01s
Model without BatchNorm
It 0th time:0.05s
It 1th time:0.05s
It 2th time:0.01s
python3 freeze_example.py --growth=16 --gpu
Model with BatchNorm
It 0th time:3.03s
It 1th time:8.32s
It 2th time:0.01s
Model without BatchNorm
It 0th time:0.05s
It 1th time:0.05s
It 2th time:0.00s
python3 freeze_example.py --growth=64 --gpu
Model with BatchNorm
It 0th time:13.77s
It 1th time:8.44s
It 2th time:0.01s
Model without BatchNorm
It 0th time:8.74s
It 1th time:0.05s
It 2th time:0.00s
I tested it in Docker with:
# CUDA 11.6 + cuDNN 8 development image on Ubuntu 20.04
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
# Install PyTorch 1.13.1 (and matching torchvision/torchaudio) built for CUDA 11.6
RUN pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116