Is it OK to put some nn.Module in the forward() function of my class?
Will they be registered correctly, and will the gradients be computed correctly?
For example, is my code below correct?
(My dropout, pooling, and upsampling layers are created in the forward() method rather than in __init__(). Or should I use the functional form instead?)
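To make the question concrete, here is a toy module (PoolDropExample is only for illustration, not my real network) showing the two styles side by side for pooling and dropout:

import torch.nn as nn
import torch.nn.functional as F

class PoolDropExample(nn.Module):
    def forward(self, x):
        # style 1: instantiate the nn.Module inside forward() on each call
        # (this is what my network below does)
        a = nn.MaxPool2d(kernel_size=2, stride=2)(x)
        a = nn.Dropout(p=0.10)(a)

        # style 2: the functional form, via torch.nn.functional
        # (note that F.dropout needs training=self.training passed by hand)
        b = F.max_pool2d(x, kernel_size=2, stride=2)
        b = F.dropout(b, p=0.10, training=self.training)
        return a, b

My full network: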
import torch
import torch.nn as nn

def make_conv_bn_prelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.PReLU(out_channels),
    ]
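# make_linear_bn_prelu is used by self.block below but wasn't defined in the
# snippet; a plausible definition, mirroring make_conv_bn_prelu, would be:
def make_linear_bn_prelu(in_features, out_features):
    return [
        nn.Linear(in_features, out_features, bias=False),
        nn.BatchNorm1d(out_features),
        nn.PReLU(out_features),
    ]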
def make_flat(out):
    # note: this also creates a module (AdaptiveAvgPool2d) on the fly
    flat = nn.AdaptiveAvgPool2d(1)(out)
    flat = flat.view(flat.size(0), -1)
    return flat
class MyNet(nn.Module):
    def __init__(self, in_shape, num_classes):
        super(MyNet, self).__init__()
        in_channels, height, width = in_shape

        self.preprocess = nn.Sequential(
            *make_conv_bn_prelu(in_channels, 8, kernel_size=1, stride=1, padding=0),
            *make_conv_bn_prelu(8, 8, kernel_size=1, stride=1, padding=0),
            *make_conv_bn_prelu(8, 8, kernel_size=1, stride=1, padding=0),
            *make_conv_bn_prelu(8, 8, kernel_size=1, stride=1, padding=0),
        )
        self.down0 = nn.Sequential(
            *make_conv_bn_prelu(8, 32),
            *make_conv_bn_prelu(32, 32, kernel_size=1, stride=1, padding=0),
        )
        self.down1 = nn.Sequential(
            *make_conv_bn_prelu(32, 32),
            *make_conv_bn_prelu(32, 32, kernel_size=1, stride=1, padding=0),
        )
        self.down2 = nn.Sequential(
            *make_conv_bn_prelu(32, 32),
            *make_conv_bn_prelu(32, 32, kernel_size=1, stride=1, padding=0),
        )
        self.down3 = nn.Sequential(
            *make_conv_bn_prelu(32, 32),
            *make_conv_bn_prelu(32, 32, kernel_size=1, stride=1, padding=0),
        )
        self.up2 = nn.Sequential(
            *make_conv_bn_prelu(32 + 32, 32, kernel_size=1, stride=1, padding=0),
            *make_conv_bn_prelu(32, 32),
        )
        self.up1 = nn.Sequential(
            *make_conv_bn_prelu(32 + 32, 32, kernel_size=1, stride=1, padding=0),
            *make_conv_bn_prelu(32, 32),
        )
        self.block = nn.Sequential(
            *make_linear_bn_prelu(32 + 32 + 32, 512),
            *make_linear_bn_prelu(512, 512),
        )
        self.logit = nn.Linear(512, num_classes)
    def forward(self, x):
        out = self.preprocess(x)

        down0 = self.down0(out)
        out = nn.MaxPool2d(kernel_size=2, stride=2)(down0)  # module created in forward()
        down1 = self.down1(out)
        out = nn.MaxPool2d(kernel_size=2, stride=2)(down1)
        down2 = self.down2(out)
        out = nn.MaxPool2d(kernel_size=2, stride=2)(down2)
        out = self.down3(out)
        flat3 = make_flat(out)

        up2 = nn.UpsamplingNearest2d(scale_factor=2)(out)  # module created in forward()
        up2 = torch.cat([down2, up2], 1)
        out = self.up2(up2)
        flat2 = make_flat(out)

        up1 = nn.UpsamplingNearest2d(scale_factor=2)(out)
        up1 = torch.cat([down1, up1], 1)
        out = self.up1(up1)
        flat1 = make_flat(out)

        out = torch.cat([flat1, flat2, flat3], 1)
        out = nn.Dropout(p=0.10)(out)  # module created in forward()
        out = self.block(out)
        out = self.logit(out)
        return out
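For comparison, the third option would be to declare these stateless modules once in __init__() and reuse them in forward(). A minimal sketch (MiniExample and its channel sizes are made up, not my real network):

import torch.nn as nn

class MiniExample(nn.Module):
    def __init__(self):
        super(MiniExample, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # created once here...
        self.drop = nn.Dropout(p=0.10)

    def forward(self, x):
        out = self.conv(x)
        out = self.pool(out)  # ...and reused here; train()/eval() toggles the dropout
        out = self.drop(out)
        return out

Which of the three styles is the recommended one, and do they all behave identically for parameter registration and gradients?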