# ERFNet full network definition for PyTorch
# Sept 2017
# Eduardo Romera
#######################
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np
from torch.quantization import QuantStub, DeQuantStub
import os
class DownsamplerBlock(nn.Module):
    """Downsample by 2x using a strided convolution and a max-pool in parallel.

    The two branches are computed from the same input, each passed through
    ReLU, and concatenated along the channel dimension.

    Args:
        ninput: Number of input channels.
        noutput: Number of output channels. The convolution branch produces
            ``noutput - ninput`` channels and the pooling branch passes the
            ``ninput`` input channels through, so ``noutput`` must be larger
            than ``ninput``.
        caller_name: Free-form label stored on the instance; it is not used
            by this class itself.
    """

    def __init__(self, ninput, noutput, caller_name=''):
        super().__init__()
        # The conv branch only supplies the channels the pooled input lacks.
        self.conv = nn.Conv2d(ninput, noutput - ninput, (3, 3), stride=2, padding=1, bias=True)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.caller_name = caller_name

    def forward(self, input):
        """Return the channel-wise concatenation of both downsampled branches.

        For even input height/width both branches produce half-resolution
        maps, so the concatenation along dim 1 is well-formed.
        """
        conv_output = F.relu(self.conv(input))
        pool_output = F.relu(self.pool(input))
        return torch.cat([conv_output, pool_output], 1)
class non_bottleneck_1d(nn.Module):
    """Residual block built from two pairs of factorized 3x1 / 1x3 convolutions.

    The second pair is dilated. All convolutions preserve both the channel
    count and the spatial size, so the input can be added back as an
    identity shortcut.

    Args:
        chann: Number of input and output channels.
        dropprob: Accepted for interface compatibility; the visible code
            does not use it (no dropout layer is created).
        dilated: Dilation rate applied to the second 3x1 / 1x3 pair. The
            padding is scaled by the same factor to keep the spatial size.
        caller_name: Free-form label stored on the instance; it is not used
            by this class itself.
    """

    def __init__(self, chann, dropprob, dilated, caller_name=''):
        super().__init__()
        self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))
        self.caller_name = caller_name

    def hook(self, conv, input, out):
        """Debug callback with a forward-hook style signature; only prints a marker."""
        print("in hook")

    def forward(self, input):
        """Apply the four convolutions, add the input back, and apply ReLU."""
        output = F.relu(self.conv3x1_1(input))
        output = F.relu(self.conv1x3_1(output))
        output = F.relu(self.conv3x1_2(output))
        output = self.conv1x3_2(output)
        # +input = identity (residual connection)
        return F.relu(output + input)
class Encoder(nn.Module):
    """ERFNet encoder: three downsampling stages interleaved with residual blocks.

    The initial block maps 4 input channels to 16; the following stages run
    at 64 and 128 channels.

    Args:
        num_classes: Number of output channels of the 1x1 classification
            convolution used when ``forward`` is called with a truthy
            ``predict``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.initial_block = DownsamplerBlock(4, 16)

        self.layers = nn.ModuleList()
        self.layers.append(DownsamplerBlock(16, 64))

        for _ in range(0, 5):  # 5 times
            self.layers.append(non_bottleneck_1d(64, 0.03, 1))

        self.layers.append(DownsamplerBlock(64, 128))

        # NOTE(review): the dilation 2/4/8 group is taken to be the body of
        # the "2 times" loop; confirm against the original indentation.
        for _ in range(0, 2):  # 2 times
            self.layers.append(non_bottleneck_1d(128, 0.3, 2))
            self.layers.append(non_bottleneck_1d(128, 0.3, 4))
            self.layers.append(non_bottleneck_1d(128, 0.3, 8))

        # only for encoder mode:
        self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)

    def forward(self, input, predict=torch.tensor(0)):
        """Encode ``input``; optionally project the features to class scores.

        Args:
            input: Tensor fed to the initial downsampler (4 channels, per
                the constructor).
            predict: Truthy value (a one-element tensor by default) that
                enables the final 1x1 ``output_conv``. When falsy, the
                128-channel features are returned unchanged.
        """
        output = self.initial_block(input)

        for layer in self.layers:
            output = layer(output)

        if predict:
            output = self.output_conv(output)

        return output
class UpsamplerBlock(nn.Module):
    """Upsample by 2x with a transposed convolution followed by ReLU.

    Args:
        ninput: Number of input channels.
        noutput: Number of output channels.
        caller_name: Free-form label stored on the instance; it is not used
            by this class itself.
    """

    def __init__(self, ninput, noutput, caller_name=''):
        super().__init__()
        # kernel 3, stride 2, padding 1, output_padding 1 -> exactly 2x size.
        self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)
        self.caller_name = caller_name

    def forward(self, input):
        """Return the ReLU-activated, spatially doubled feature map."""
        return F.relu(self.conv(input))
class Decoder(nn.Module):
    """ERFNet decoder: two upsampling stages with residual blocks, then a classifier.

    Takes 128-channel features, upsamples to 64 and then 16 channels, and
    finishes with a stride-2 transposed convolution that produces
    ``num_classes`` channels.

    Args:
        num_classes: Number of output channels of the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(128, 64))
        self.layers.append(non_bottleneck_1d(64, 0, 1))
        self.layers.append(non_bottleneck_1d(64, 0, 1))

        self.layers.append(UpsamplerBlock(64, 16))
        self.layers.append(non_bottleneck_1d(16, 0, 1))
        self.layers.append(non_bottleneck_1d(16, 0, 1))

        self.output_conv = nn.ConvTranspose2d(16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True)

    def forward(self, input):
        """Run ``input`` through every decoder layer and the output convolution."""
        output = input

        for layer in self.layers:
            output = layer(output)

        return self.output_conv(output)
# ERFNet
class ERFNet(nn.Module):
    """Full ERFNet: encoder + decoder wrapped in quantization stubs.

    Args:
        num_classes: Number of output channels of the prediction layers.
        encoder: Optional pre-built (e.g. pretrained) encoder module. When
            ``None``, a fresh ``Encoder(num_classes)`` is created.
    """

    def __init__(self, num_classes, encoder=None):  # use encoder to pass pretrained encoder
        super().__init__()

        # Identity check: `== None` would defer to any custom __eq__.
        if encoder is None:
            self.encoder = Encoder(num_classes)
        else:
            self.encoder = encoder
        self.decoder = Decoder(num_classes)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, input, only_encode=torch.tensor(0)):
        """Run the network between the quant and dequant stubs.

        Args:
            input: Input tensor passed to the quant stub and then the encoder.
            only_encode: Truthy value (a one-element tensor by default).
                When truthy, the encoder's own prediction head is returned
                and the decoder is skipped; otherwise the encoder features
                are decoded.
        """
        x = self.quant(input)
        # Modules are called directly (not via .forward) so that any
        # registered hooks run.
        if only_encode:
            x = self.encoder(x, predict=torch.tensor(1))
        else:
            x = self.decoder(self.encoder(x))  # predict is falsy by default
        return self.dequant(x)
# Script entry point: print the working directory for debugging.
if __name__ == '__main__':
    print(os.getcwd())
    # NOTE(review): `cur_quant_dir` is not defined anywhere in the visible
    # source; unless it is defined elsewhere in this file, this line raises
    # NameError when the module is run as a script.
    print(cur_quant_dir)