Hi!
I created my own UNet model following some walkthrough tutorials, trying to be as close to original paper as possible, making some small changes as nb_classes, padding. Model looks like this:
def double_conv(in_c, out_c):
conv = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
return conv
def crop_img(tensor, target_tensor):
# square images
target_size = target_tensor.size()[2]
if tensor.size()[2] % 2 == 1:
tensor_size = tensor.size()[2]1
else:
tensor_size = tensor.size()[2]
delta = tensor_size  target_size
delta = delta // 2
return tensor[:, :, delta:tensor_sizedelta, delta:tensor_sizedelta]
class UNet(nn.Module):
def __init__(self, nb_classes):
super(UNet, self).__init__()
self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.down_conv_1 = double_conv(3, 64)
self.down_conv_2 = double_conv(64, 128)
self.down_conv_3 = double_conv(128, 256)
self.down_conv_4 = double_conv(256, 512)
self.down_conv_5 = double_conv(512, 1024)
## transposed convolutions
self.up_trans_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
self.up_conv_1 = double_conv(1024, 512)
self.up_trans_2 = nn.ConvTranspose2d(512, 256, 2, 2)
self.up_conv_2 = double_conv(512, 256)
self.up_trans_3 = nn.ConvTranspose2d(256, 128, 2, 2)
self.up_conv_3 = double_conv(256, 128)
self.up_trans_4 = nn.ConvTranspose2d(128, 64, 2, 2)
self.up_conv_4 = double_conv(128, 64)
self.out = nn.Conv2d(64, nb_classes, 1)
def forward(self, image):
# encoder part
# input image
x1 = self.down_conv_1(image) # this is passed to decoder
# max pooling
x2 = self.max_pool_2x2(x1)
x3 = self.down_conv_2(x2) # this is passed to decoder
x4 = self.max_pool_2x2(x3)
x5 = self.down_conv_3(x4) # this is passed to decoder
x6 = self.max_pool_2x2(x5)
x7 = self.down_conv_4(x6) # this is passed to decoder
x8 = self.max_pool_2x2(x7)
x9 = self.down_conv_5(x8)
# decoder part
x = self.up_trans_1(x9)
y = crop_img(x7, x)
x = self.up_conv_1(torch.cat([x, y], 1))
x = self.up_trans_2(x)
y = crop_img(x5, x)
x = self.up_conv_2(torch.cat([x, y], 1))
x = self.up_trans_3(x)
y = crop_img(x3, x)
x = self.up_conv_3(torch.cat([x, y], 1))
x = self.up_trans_4(x)
y = crop_img(x1, x)
x = self.up_conv_4(torch.cat([x, y], 1))
x = self.out(x)
return x
Then before jumping to some larger data, I wanted to learn and understand how everything works, so I created images and masks using this github.
image.shape: (512, 512, 3)
image.min(), image.max(): 0.0 255.0
mask.shape: (6, 512, 512)
mask.min(), mask.max(): 0.0 1.0
So the masks are 6channel arrays, where each channel represents one color and sometimes channels overlay themselves, as on this example image below.
What I’m doing next is just preparing dummy train and evaluation, please forgive me if this code hurts your eyes
model = UNet(nb_classes=6)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e4)
#summary(model, input_size=(3, 128, 128))
for t in range(99):
image = load_image(f'data/images/image_{t}.jpg')
mask = np.load(f'data/masks/mask_{t}.npy')
# loading images as 3 channel RGB
image = np.expand_dims(image, 0)
image = image.transpose(0, 3, 1, 2)
input_data = torch.tensor(image)
# loading masks as numpy 6 channel arrays
mask_data = torch.tensor(mask)
# argmax to combine all channels?
y_true = torch.argmax(mask_data, dim=0)
# adding one BATCH dimension
y_true = y_true.unsqueeze(0)
y_pred = model(input_data)
loss = criterion(y_pred, y_true)
print(t, loss.item())
# Zero gradients, perform a backward pass, and update the weights.
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.eval()
image = load_image('data/images/image_102.jpg')
image = np.expand_dims(image, 0)
image = image.transpose(0, 3, 1, 2)
input_data = torch.tensor(image)
pred = model(input_data)
pred = F.sigmoid(pred)
print(pred.shape)
np.save('pred', pred)
Summarizing, I have following questions:
 As I understand, in the multilabel segmentations layers can overlay (each pixel can be assigned to multiple objects), where in the multiclass segmentation, each pixel can be assigned only to one object?
So for the multiclass, we can have one channel mask (3x3) which can look like this:
[ 0 1 2  1 2 3  4 5 5 ]
We would prepare targets as [ batch_size, channel=1, height, width ]
, please correct me if I am wrong.
Which loss function should I use in this case?

How to proceed with the multilabel segmentation?

Right now the loss after ~ 40 loops drops to around 0.05. I’m using code from the linked github:
loaded_array = np.load('pred.npy')
loaded_array.shape
pred_rgb = [helper.masks_to_colorimg(x) for x in loaded_array]
plt.imshow(pred_rgb[0])
Masks for the 102 image look like this (after 99 loops):
99 loops
What am I doing wrong?