"""This module creates SRNet model."""
import torch
from torch import Tensor
from torch import nn
from model.utils import Type1, Type2, Type3, Type4
class Srnet(nn.Module):
    """SRNet model assembled from Type1-Type4 layers.

    Layer stack, in order:
        * two Type1 layers (1 -> 64 -> 16 channels),
        * five Type2 layers (16 channels each),
        * four Type3 layers (16 -> 16 -> 64 -> 128 -> 256 channels),
        * one Type4 layer (256 -> 512 channels),
        * a linear classifier (512 -> 2) followed by ``LogSoftmax``.
    """

    def __init__(self) -> None:
        """Build the layer groups, the classifier head and the log-softmax."""
        super().__init__()
        self.type1s = nn.Sequential(Type1(1, 64), Type1(64, 16))
        # Five identical 16-channel Type2 layers.
        self.type2s = nn.Sequential(*(Type2(16, 16) for _ in range(5)))
        self.type3s = nn.Sequential(
            Type3(16, 16),
            Type3(16, 64),
            Type3(64, 128),
            Type3(128, 256),
        )
        self.type4 = Type4(256, 512)
        self.dense = nn.Linear(512, 2)
        # Normalizes over the class dimension (dim=1) of the (Batch, 2) output.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Return class log-probabilities for input images.

        Args:
            inp (Tensor): input image tensor of shape (Batch, 1, 256, 256).

        Returns:
            Tensor: log-probabilities of shape (Batch, 2). The values have
            already passed through ``LogSoftmax``, so they are not raw
            logits; they are suitable for ``nn.NLLLoss``.
        """
        out = self.type1s(inp)
        out = self.type2s(out)
        out = self.type3s(out)
        out = self.type4(out)
        # Flatten everything but the batch dimension so the features match
        # the 512 inputs expected by the linear classifier.
        out = out.view(out.size(0), -1)
        out = self.dense(out)
        return self.softmax(out)
class ConvBn(nn.Module):
    """3x3 convolution (stride 1, padding 1, with bias) followed by BatchNorm2d."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Constructor.

        Args:
            in_channels (int): no. of input channels.
            out_channels (int): no. of output channels.
        """
        super().__init__()
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, inp: Tensor) -> Tensor:
        """Return Conv2d followed by BatchNorm.

        Args:
            inp (Tensor): input tensor of shape (Batch, in_channels, H, W).

        Returns:
            Tensor: Output of Conv2D -> BN, shape (Batch, out_channels, H, W).
        """
        return self.batch_norm(self.conv(inp))
class Type1(nn.Module):
    """Creates type 1 layer of SRNet: ConvBn followed by ReLU."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Constructor.

        Args:
            in_channels (int): no. of input channels.
            out_channels (int): no. of output channels.
        """
        super().__init__()
        self.convbn = ConvBn(in_channels, out_channels)
        self.relu = nn.ReLU()

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 1 layer of SRNet.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: Output of type 1 layer, i.e. ReLU(ConvBn(inp)).
        """
        return self.relu(self.convbn(inp))
class Type2(nn.Module):
    """Creates type 2 layer of SRNet: Type1 -> ConvBn with a skip connection."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Constructor.

        Args:
            in_channels (int): no. of input channels.
            out_channels (int): no. of output channels. The skip connection
                adds the input to the output, so the two shapes must be
                compatible for addition (normally
                ``in_channels == out_channels``).
        """
        super().__init__()
        self.type1 = Type1(in_channels, out_channels)
        # The second stage consumes the Type1 output, which has
        # ``out_channels`` channels (not ``in_channels``).
        self.convbn = ConvBn(out_channels, out_channels)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 2 layer of SRNet.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: Output of type 2 layer, i.e. inp + ConvBn(Type1(inp)).
        """
        return inp + self.convbn(self.type1(inp))
class Type3(nn.Module):
    """Creates type 3 layer of SRNet.

    Two branches are summed:
        * shortcut: 1x1 convolution with stride 2, then BatchNorm2d;
        * main: Type1 -> ConvBn -> 3x3 average pooling with stride 2.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Constructor.

        Args:
            in_channels (int): no. of input channels.
            out_channels (int): no. of output channels.
        """
        super().__init__()
        # Shortcut branch: strided 1x1 convolution + batch norm.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=2,
            padding=0,
            bias=True,
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)
        # Main branch: Type1 -> ConvBn -> strided average pooling.
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(out_channels, out_channels)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 3 layer of SRNet.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: Output of type 3 layer, the sum of the shortcut branch
            and the main branch. Both branches use stride 2.
        """
        shortcut = self.batch_norm(self.conv1(inp))
        main = self.pool(self.convbn(self.type1(inp)))
        return shortcut + main
class Type4(nn.Module):
    """Creates type 4 layer of SRNet: Type1 -> ConvBn -> global average pool."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Constructor.

        Args:
            in_channels (int): no. of input channels.
            out_channels (int): no. of output channels.
        """
        super().__init__()
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(out_channels, out_channels)
        # output_size=1 averages each channel down to a single 1x1 value.
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Returns type 4 layer of SRNet.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: Output of type 4 layer, with spatial size 1x1 per
            channel after the adaptive average pooling.
        """
        return self.gap(self.convbn(self.type1(inp)))
"""This module provides utility functions for training."""
import os
import re
from typing import Any, Dict, Optional

import torch
from torch import nn

from opts.options import arguments
opt = arguments()
def saver(state: Dict[str, float], save_dir: str, epoch: int) -> None:
    """Save a training state to ``<save_dir>net_<epoch>.pt``.

    Args:
        state: object handed to ``torch.save``.
        save_dir: path prefix for the checkpoint file. It is concatenated
            as-is (no separator is inserted), so pass a directory path that
            ends with a path separator.
        epoch: epoch number embedded in the file name.
    """
    torch.save(state, f"{save_dir}net_{epoch}.pt")
def latest_checkpoint(checkpoints_dir: Optional[str] = None) -> Optional[int]:
    """Returns latest checkpoint.

    Args:
        checkpoints_dir: directory to scan. When omitted, the configured
            ``opt.checkpoints_dir`` is used.

    Returns:
        The largest run of digits (as an int) found in any entry name of the
        directory, or ``None`` when the path is not an existing directory,
        the directory is empty, or no entry name contains a digit.
    """
    directory = opt.checkpoints_dir if checkpoints_dir is None else checkpoints_dir
    if not os.path.isdir(directory):
        return None
    # Scan each name on its own so trailing digits of one entry can never
    # merge with leading digits of the next one.
    numbers = [
        int(digits)
        for name in os.listdir(directory)
        for digits in re.findall(r"\d+", name)
    ]
    # default=None avoids a ValueError when no entry name contains digits.
    return max(numbers, default=None)
def adjust_learning_rate(
    optimizer: Any, epoch: int, base_lr: Optional[float] = None
) -> None:
    """Sets the learning rate to the initial learning_rate and decays by 10
    every 30 epochs.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts;
            the ``"lr"`` entry of every group is overwritten.
        epoch: current epoch; the rate is multiplied by 0.1 once per 30
            completed epochs (``epoch // 30``).
        base_lr: initial learning rate. When omitted, the configured
            ``opt.lr`` is used.
    """
    initial_lr = opt.lr if base_lr is None else base_lr
    learning_rate = initial_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate
# Weight initialization for conv layers and fc layers
def weights_init(param: Any) -> None:
    """Initializes weights of Conv, BatchNorm and fully connected layers.

    The module is modified in place; any other module type is left untouched.

    * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU), bias set to 0.2.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights drawn from N(0, 0.01), bias set to 0.

    Parameters that are absent (``bias=False`` or ``affine=False``) are
    skipped instead of raising.

    Args:
        param: the module to initialize.
    """
    if isinstance(param, nn.Conv2d):
        nn.init.kaiming_normal_(param.weight, mode="fan_out", nonlinearity="relu")
        if param.bias is not None:
            nn.init.constant_(param.bias, 0.2)
    elif isinstance(param, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False.
        if param.weight is not None:
            nn.init.constant_(param.weight, 1)
        if param.bias is not None:
            nn.init.constant_(param.bias, 0)
    elif isinstance(param, nn.Linear):
        nn.init.normal_(param.weight, 0, 0.01)
        # bias is None when the layer was built with bias=False.
        if param.bias is not None:
            nn.init.constant_(param.bias, 0)