Yes, I used my custom UNet-like model:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PartialConvLayer(nn.Module):

    def __init__(self, in_channels, out_channels, kernel, bn=True, bias=False, sample="none-3", activation="relu"):
        super().__init__()
        self.bn = bn
        self.kernel_size = kernel
        self.in_channel = in_channels

        if sample == "down-7":
            # Kernel Size = 7, Stride = 2, Padding = 3
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 3, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 3, bias=False)
        elif sample == "down-5":
            # Kernel Size = 5, Stride = 2, Padding = 2
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 2, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 2, bias=False)
        elif sample == "down-3":
            # Kernel Size = 3, Stride = 2, Padding = 1
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 1, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 1, bias=False)
        else:
            # Kernel Size = 3, Stride = 1, Padding = 1 (no downsampling)
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 1, 1, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 1, 1, bias=False)

        nn.init.constant_(self.mask_conv.weight, 1.0)

        # Initialize weights using Kaiming initialization
        # a: negative slope of the rectifier, set to 0 (i.e. plain ReLU)
        # "fan_in" preserves the variance of the forward pass
        nn.init.kaiming_normal_(self.input_conv.weight, a=0, mode="fan_in")

        # The mask convolution is fixed to all-ones weights and is never trained
        for param in self.mask_conv.parameters():
            param.requires_grad = False

        if bn:
            # Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
            # Applying a BatchNorm2d layer after the conv removes the channel mean
            self.batch_normalization = nn.BatchNorm2d(out_channels)

        if activation == "relu":
            # Used between all encoding layers
            self.activation = nn.ReLU()
        elif activation == "leaky_relu":
            # Used between all decoding layers (Leaky ReLU with alpha = 0.2)
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input_x, mask):
        # output = W^T dot (X .* M) + b
        output = self.input_conv(input_x * mask)

        # requires_grad = False
        with torch.no_grad():
            # mask = (1 dot M) + 0 = M
            output_mask = self.mask_conv(mask)

        if self.input_conv.bias is None:
            output_bias = 0
        else:
            # Only the last layer has a bias term
            # ** For ease of quantization and for a future 1-channel conversion,
            #    the bias term is hardcoded
            # output_bias = self.input_conv.bias.reshape(1, -1, 1, 1)
            output_bias = 0.56

        # Mask update
        updated_mask = torch.clamp(output_mask, 0, 1)

        # Output image update
        num_scaling = torch.ones_like(output) * self.kernel_size * self.kernel_size
        denom_scaling = output_mask / self.in_channel
        scaling_factor = num_scaling / denom_scaling

        updated_output = (output - output_bias) * scaling_factor + output_bias
        updated_output = updated_output * updated_mask

        if self.bn:
            updated_output = self.batch_normalization(updated_output)

        if hasattr(self, 'activation'):
            updated_output = self.activation(updated_output)

        return updated_output, updated_mask
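For reference, a single encoding layer on dummy all-ones data behaves like this (just a quick shape check, not part of my training code):

layer = PartialConvLayer(3, 64, 7, bn=False, sample="down-7")
x = torch.ones([1, 3, 256, 256])
m = torch.ones([1, 3, 256, 256])
out, new_mask = layer(x, m)
print(out.shape, new_mask.shape)  # both torch.Size([1, 64, 128, 128])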
class PartialConvUNet(nn.Module):

    # 256 x 256 image input, 256 = 2^8
    def __init__(self, input_size=256, layers=7):
        if 2 ** (layers + 1) != input_size:
            raise AssertionError
        super().__init__()

        self.freeze_enc_bn = False
        self.layers = layers

        # ======================= ENCODING LAYERS =======================
        # 3x256x256 --> 64x128x128
        self.encoder_1 = PartialConvLayer(3, 64, 7, bn=False, sample="down-7")

        # 64x128x128 --> 128x64x64
        self.encoder_2 = PartialConvLayer(64, 128, 5, sample="down-5")

        # 128x64x64 --> 256x32x32
        self.encoder_3 = PartialConvLayer(128, 256, 3, sample="down-3")

        # 256x32x32 --> 512x16x16
        self.encoder_4 = PartialConvLayer(256, 512, 3, sample="down-3")

        # 512x16x16 --> 512x8x8 --> 512x4x4 --> 512x2x2
        for i in range(5, layers + 1):
            name = "encoder_{:d}".format(i)
            setattr(self, name, PartialConvLayer(512, 512, 3, sample="down-3"))

        # ======================= DECODING LAYERS =======================
        # dec_7: UP(512x2x2) + 512x4x4 (enc_6 output) = 1024x4x4 --> 512x4x4
        # dec_6: UP(512x4x4) + 512x8x8 (enc_5 output) = 1024x8x8 --> 512x8x8
        # dec_5: UP(512x8x8) + 512x16x16 (enc_4 output) = 1024x16x16 --> 512x16x16
        for i in range(layers, 0, -1):
            name = "decoder_{:d}".format(i)
            setattr(self, name, PartialConvLayer(512 + 512, 512, 3, activation="leaky_relu"))

        # UP(512x16x16) + 256x32x32 (enc_3 output) = 768x32x32 --> 256x32x32
        self.decoder_4 = PartialConvLayer(512 + 256, 256, 3, activation="leaky_relu")

        # UP(256x32x32) + 128x64x64 (enc_2 output) = 384x64x64 --> 128x64x64
        self.decoder_3 = PartialConvLayer(256 + 128, 128, 3, activation="leaky_relu")

        # UP(128x64x64) + 64x128x128 (enc_1 output) = 192x128x128 --> 64x128x128
        self.decoder_2 = PartialConvLayer(128 + 64, 64, 3, activation="leaky_relu")

        # UP(64x128x128) + 3x256x256 (original image) = 67x256x256 --> 3x256x256 (final output)
        self.decoder_1 = PartialConvLayer(64 + 3, 3, 3, bn=False, activation="", bias=True)

    def forward(self, input_x, mask):
        encoder_dict = {}
        mask_dict = {}

        key_prev = "h_0"
        encoder_dict[key_prev], mask_dict[key_prev] = input_x, mask

        # Encoder path
        for i in range(1, self.layers + 1):
            encoder_key = "encoder_{:d}".format(i)
            key = "h_{:d}".format(i)
            # Pass the input and mask through the encoding layer
            encoder_dict[key], mask_dict[key] = getattr(self, encoder_key)(encoder_dict[key_prev], mask_dict[key_prev])
            key_prev = key

        # Get the final output data and mask from the encoding layers
        # 512 x 2 x 2
        out_key = "h_{:d}".format(self.layers)
        out_data, out_mask = encoder_dict[out_key], mask_dict[out_key]

        # Decoder path
        for i in range(self.layers, 0, -1):
            encoder_key = "h_{:d}".format(i - 1)
            decoder_key = "decoder_{:d}".format(i)

            # Upsample by a factor of 2, matching the spatial dimensions of the previous encoding layer's output
            out_data = F.interpolate(out_data, scale_factor=2)
            out_mask = F.interpolate(out_mask, scale_factor=2)

            # Concatenate the upsampled decoder output with the encoder output of the same H x W,
            # so that the final decoding layer's input contains the original image
            out_data = torch.cat([out_data, encoder_dict[encoder_key]], dim=1)
            # Also concatenate the masks
            out_mask = torch.cat([out_mask, mask_dict[encoder_key]], dim=1)

            # Feed through the decoder layer
            out_data, out_mask = getattr(self, decoder_key)(out_data, out_mask)

        return out_data

    def train(self, mode=True):
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and "enc" in name:
                    # Keep encoder batch normalization layers in evaluation mode
                    module.eval()
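For what it's worth, the plain FP32 model runs on dummy data without issue (all-ones image and mask, shape check only):

model = PartialConvUNet()
model.eval()
with torch.no_grad():
    out = model(torch.ones([1, 3, 256, 256]), torch.ones([1, 3, 256, 256]))
print(out.shape)  # torch.Size([1, 3, 256, 256])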
Then I just ran this:
# 1. FX Mode quantization
import copy
from torch.quantization import quantize_fx

model_fp = PartialConvUNet()
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}

# prepare (insert observers)
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)

# calibrate
model_prepared.eval()
img = torch.ones([1, 3, 256, 256])
mask = torch.ones([1, 3, 256, 256])
model_prepared(img, mask)
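I haven't run the conversion step yet; if it's relevant, I assume the next step would be:

# convert (not run yet)
model_quantized = quantize_fx.convert_fx(model_prepared)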
Should I switch to eager mode quantization instead?