Hi, this is my first time building a model with PyTorch; I am translating a U-Net from TensorFlow. Could someone give me a sanity check that the model looks valid?
class UNet(nn.Module):
    """2-D U-Net: an encoder/decoder CNN with skip connections.

    The encoder applies ``pooling_steps`` conv blocks, each followed by a
    2x2 max-pool; the channel count doubles at every level, starting from
    ``init_features``. A bottleneck block runs at the lowest resolution.
    The decoder mirrors the encoder: a stride-2 transposed convolution
    upsamples, the result is concatenated with the matching encoder output,
    and a conv block fuses them. A final 1x1 convolution maps to
    ``out_channels`` and a sigmoid is applied.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output map.
        init_features: Channel count of the first encoder level.
        pooling_steps: Number of down/up-sampling levels (>= 0).

    Raises:
        ValueError: If ``pooling_steps`` is negative.

    Note:
        Input height and width must be divisible by ``2 ** pooling_steps``;
        otherwise max-pooling floors the size, the upsampled tensor no longer
        matches its skip connection, and ``torch.cat`` fails.
        The blocks contain ``BatchNorm2d``, so call ``model.eval()`` before
        inference to use running statistics instead of per-batch ones.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=64, pooling_steps=2):
        super().__init__()
        if pooling_steps < 0:
            raise ValueError(f"pooling_steps must be >= 0, got {pooling_steps}")
        features = init_features
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upconv = nn.ModuleList()

        # Each encoder consumes exactly what the previous level produced
        # (the raw input for the first level). Using a fixed ``features``
        # here would only be correct for the first two levels.
        prev_channels = in_channels
        for i in range(pooling_steps):
            level_channels = features * 2**i
            self.encoders.append(
                UNet._block(prev_channels, level_channels, name=f"enc{i + 1}")
            )
            # Decoders/up-convs are prepended so index 0 is the deepest level,
            # which is the order forward() visits them in. Their input is
            # twice the level width: upsampled features + skip connection.
            self.decoders.insert(
                0, UNet._block(level_channels * 2, level_channels, name=f"dec{i + 1}")
            )
            self.upconv.insert(
                0,
                nn.ConvTranspose2d(
                    level_channels * 2, level_channels, kernel_size=2, stride=2
                ),
            )
            prev_channels = level_channels

        self.bottleneck = UNet._block(
            prev_channels, features * 2**pooling_steps, name="bottleneck"
        )
        self.conv = nn.Conv2d(in_channels=features, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network and return per-pixel sigmoid outputs.

        Args:
            x: Input tensor, assumed (N, in_channels, H, W) — skip connections
                are concatenated along dim 1, which is the channel axis only
                for batched 4-D input.

        Returns:
            Tensor with ``out_channels`` channels, the same spatial size as
            ``x`` (given the divisibility requirement), values in [0, 1].
        """
        skips = []
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # upconv/decoders are stored deepest-first, so pair them with the
        # skip connections in reverse order of collection.
        for upconv, decoder, skip in zip(self.upconv, self.decoders, reversed(skips)):
            x = upconv(x)
            x = torch.cat((x, skip), dim=1)
            x = decoder(x)

        # NOTE(review): sigmoid is applied here, so train with a loss that
        # expects probabilities (e.g. nn.BCELoss), not nn.BCEWithLogitsLoss.
        return torch.sigmoid(self.conv(x))

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv -> BatchNorm -> ReLU) stages.

        Padding of 1 keeps the spatial size unchanged. The convolutions have
        no bias because the following BatchNorm adds its own shift.

        Args:
            in_channels: Channels entering the first convolution.
            features: Output channels of both convolutions.
            name: Label for the block; currently unused (kept for the call
                sites' readability — sub-modules are indexed numerically).

        Returns:
            An ``nn.Sequential`` with the six layers above.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=features,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )
Also, what strategy could I follow to speed up inference? I read that it is possible to use half precision (16-bit floats), but I am not sure how to implement it. Is it applied to the input tensor, or is it a property of the model?
Any other tips are welcome.