I’m not sure about tracing, but I was able to script it successfully with the following:
class DecoderBlockLinkNet(torch.jit.ScriptModule):
    """Upsampling decoder block compiled with TorchScript.

    The block applies three stages, each followed by batch normalisation and
    ReLU:

    1. a 1x1 convolution that reduces the channels to ``in_channels // 4``;
    2. a transposed convolution (kernel 4, stride 2, padding 1) that doubles
       the height and width;
    3. a 1x1 convolution that produces ``n_filters`` channels.

    Args:
        in_channels: Number of channels of the input tensor.
        n_filters: Number of channels of the output tensor.
    """

    def __init__(self, in_channels: int, n_filters: int) -> None:
        super().__init__()
        reduced = in_channels // 4

        # One parameter-free, in-place ReLU instance is reused after every
        # normalisation layer.
        self.relu = nn.ReLU(inplace=True)

        # B, C, H, W -> B, C/4, H, W
        self.conv1 = nn.Conv2d(in_channels, reduced, 1)
        self.norm1 = nn.BatchNorm2d(reduced)

        # B, C/4, H, W -> B, C/4, 2 * H, 2 * W
        self.deconv2 = nn.ConvTranspose2d(
            reduced,
            reduced,
            kernel_size=4,
            stride=2,
            padding=1,
            output_padding=0,
        )
        self.norm2 = nn.BatchNorm2d(reduced)

        # B, C/4, 2 * H, 2 * W -> B, n_filters, 2 * H, 2 * W
        self.conv3 = nn.Conv2d(reduced, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)

    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the upsampled feature map for ``x`` (B, C, H, W)."""
        # Stage 1: channel reduction.
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        # Stage 2: 2x spatial upsampling.
        x = self.deconv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        # Stage 3: projection to the requested number of filters.
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu(x)
        return x
class UNet16(torch.jit.ScriptModule):
    """U-Net style segmentation network with a VGG16 encoder (TorchScript).

    The encoder stages are assembled from layers of
    ``torchvision.models.vgg16(...).features`` picked by index, with a shared
    ReLU placed after each one; max-pooling between stages is applied in
    ``forward``. The decoder is built from ``DecoderBlock`` and ``ConvRelu``,
    which are not defined in this snippet and must be available in the
    enclosing module.
    """

    # The conv stages are listed as in the original so the legacy script
    # compiler can treat the ``nn.Sequential`` containers as constants.
    # ``num_classes`` is added because ``forward`` branches on it; declaring a
    # plain int here is harmless on every version.
    # NOTE(review): older legacy-API releases appear to reject undeclared
    # Python attributes in script methods - confirm for your version.
    __constants__ = ["conv1", "conv2", "conv3", "conv4", "conv5", "num_classes"]

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        """
        :param num_classes: number of output channels of the final 1x1
            convolution; when greater than 1, ``forward`` applies
            ``log_softmax`` over the channel dimension
        :param num_filters: base width used to size the decoder blocks
        :param pretrained:
            False - no pre-trained network used
            True  - encoder pre-trained with VGG16
        """
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features

        self.relu = nn.ReLU(inplace=True)

        # Encoder stages: selected VGG16 feature layers, each followed by the
        # shared in-place ReLU.
        self.conv1 = nn.Sequential(
            self.encoder[0], self.relu,
            self.encoder[2], self.relu,
        )
        self.conv2 = nn.Sequential(
            self.encoder[5], self.relu,
            self.encoder[7], self.relu,
        )
        self.conv3 = nn.Sequential(
            self.encoder[10], self.relu,
            self.encoder[12], self.relu,
            self.encoder[14], self.relu,
        )
        self.conv4 = nn.Sequential(
            self.encoder[17], self.relu,
            self.encoder[19], self.relu,
            self.encoder[21], self.relu,
        )
        self.conv5 = nn.Sequential(
            self.encoder[24], self.relu,
            self.encoder[26], self.relu,
            self.encoder[28], self.relu,
        )

        # Decoder: each block's input width is the skip connection's channels
        # plus the previous decoder output's channels (see the ``torch.cat``
        # calls in ``forward``).
        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    @torch.jit.script_method
    def forward(self, x):
        """Run the encoder/decoder and return the per-pixel output map.

        Returns ``log_softmax`` over dim 1 when ``num_classes > 1``; otherwise
        the raw output of the final 1x1 convolution.
        """
        # Encoder path: pool between consecutive stages.
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))

        center = self.center(self.pool(conv5))

        # Decoder path: concatenate each decoder output with the matching
        # encoder feature map along the channel dimension.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        if self.num_classes > 1:
            x_out = F.log_softmax(self.final(dec1), dim=1)
        else:
            x_out = self.final(dec1)
        return x_out
# Build the network with the constructor's default settings spelled out:
# a single output channel, 32 base filters and no pre-trained encoder weights.
model = UNet16(num_classes=1, num_filters=32, pretrained=False)