I am working on a multiclass semantic segmentation project with 4 classes (background not included), using COCO JSON as the data format (the original dataset has 10 train images and 4 test images, but I also have a consistent data augmentation script that produces 100 train images and 20 test images), PyTorch as the framework, and a U-Net from the SMP library with a pretrained encoder. When I use the Dataset and DataLoader classes to generate batches for the model, the batches have the following shapes:
batch_images, batch_masks = batch
batch_images.shape: torch.Size([1, 3, 1024, 1024])
type(batch_images): <class 'torch.Tensor'>
batch_images.dtype: torch.float32
batch_masks.shape: torch.Size([1, 4, 1024, 1024])
type(batch_masks): <class 'torch.Tensor'>
batch_masks.dtype: torch.float32
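To illustrate how masks with this layout can be built from the COCO annotations, here is a simplified sketch (not my exact Dataset code; it assumes the category ids in the JSON run from 1 to 4 and that the images are already 1024x1024):

import numpy as np
import torch
from pycocotools.coco import COCO

def build_onehot_mask(coco: COCO, img_id: int, num_classes: int = 4, size=(1024, 1024)):
    # One binary channel per class (background has no channel), stacked into [num_classes, H, W]
    mask = np.zeros((num_classes, *size), dtype=np.float32)
    for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
        class_idx = ann["category_id"] - 1          # assumes category ids 1..num_classes
        mask[class_idx] = np.maximum(mask[class_idx], coco.annToMask(ann))
    return torch.from_numpy(mask)                   # torch.float32, shape [4, 1024, 1024]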
I also have a method inside the Dataset object that returns the class frequencies (as pixel counts) and the class weights (calculated with the inverse-frequency method); for the augmented dataset it shows:
Class Frequencies_Train: [16707375. 3310340. 3782234. 3841229.]
Class Weights_Train: tensor([0.1981, 1.0000, 0.8752, 0.8618])
Class Frequencies_Test: [1167281. 1984024. 1058666. 1753069.]
Class Weights_Test: tensor([0.9070, 0.5336, 1.0000, 0.6039])
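For reference, these weights are just the reciprocal of each pixel count, normalized so that the rarest class gets weight 1.0; the Dataset method is essentially doing something like this:

import torch

def inverse_frequency_weights(pixel_counts):
    # weight_c = (1 / freq_c) / max(1 / freq), so the rarest class gets 1.0
    inv = 1.0 / torch.tensor(pixel_counts, dtype=torch.float32)
    return inv / inv.max()

class_weights_train = inverse_frequency_weights([16707375, 3310340, 3782234, 3841229])
# -> tensor([0.1981, 1.0000, 0.8752, 0.8618]), matching the values above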
Then I am training with the following script:
import os
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

num_classes = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# class_weights_train / class_weights_test come from the Dataset method described above
class_weights_train = torch.tensor(class_weights_train).to(device)
class_weights_test = torch.tensor(class_weights_test).to(device)
encoder = "resnet101"
e_weights = "imagenet"
model_Unet = smp.Unet(
    encoder_name=encoder,
    encoder_weights=e_weights,
    in_channels=3,
    classes=num_classes,
)
criterion = nn.CrossEntropyLoss()
learning_rate = 0.0001
weight_decay = 0.0005
optimizer = torch.optim.Adam(model_Unet.parameters(), lr=learning_rate, weight_decay=weight_decay)
model_Unet.to(device)
model_Unet.train()
#Training loop
session = 135
session_folder = f"training_session_{session}"
os.makedirs(session_folder, exist_ok=True)
num_epochs = 300
checkpoint_interval = 100
train_losses = []
train_iou = []
test_losses = []
test_iou = []
for epoch in range(num_epochs):
    model_Unet.train()
    train_loss = 0.0
    intersection = 0
    union = 0
    for i, (images, masks) in enumerate(dataloader_train):
        optimizer.zero_grad()
        images = images.to(device)
        masks = masks.squeeze(1).float().to(device)  # one-hot float masks, [B, 4, H, W]
        outputs = model_Unet(images)
        loss = criterion(outputs, masks)
        class_weights = torch.tensor(class_weights_train, dtype=torch.float32, device=device)
        weighted_loss = loss * class_weights
        train_loss += torch.sum(weighted_loss).item()
        loss.backward()
        optimizer.step()
        predicted_masks = torch.argmax(outputs, dim=1).float()
        intersection += torch.sum(predicted_masks * masks).item()
        union += torch.sum((predicted_masks + masks) > 0).item()
        # Loss
        train_loss += loss.item()
        print(f'Epoch: {epoch+1}/{num_epochs}\t Iteration: {i+1}/{len(dataloader_train)}')
    # Average Loss
    train_loss /= len(dataloader_train.dataset)
    # IoU
    iou_train = intersection / union
    train_losses.append(train_loss)
    train_iou.append(iou_train)
    print(f'Epoch: {epoch+1}/{num_epochs}\t Training Loss: {train_loss}\t IoU: {iou_train}')
    model_Unet.eval()
    test_loss = 0.0
    intersection_test = 0
    union_test = 0
    with torch.no_grad():
        for images, masks in dataloader_test:
            images = images.to(device)
            masks = masks.squeeze(1).float().to(device)
            outputs = model_Unet(images)
            loss = criterion(outputs, masks)
            class_weights = torch.tensor(class_weights_test, dtype=torch.float32, device=device)
            weighted_loss = loss * class_weights
            test_loss += torch.sum(weighted_loss).item()
            predicted_masks = torch.argmax(outputs, dim=1).float()
            intersection_test += torch.sum(predicted_masks * masks).item()
            union_test += torch.sum((predicted_masks + masks) > 0).item()
    # Average Test Loss
    test_loss /= len(dataloader_test.dataset)
    # IoU
    iou_test = intersection_test / union_test
    test_losses.append(test_loss)
    test_iou.append(iou_test)
The IoU is usually not consistent:
- Using the original dataset, which was more imbalanced, the IoU reached 0.42 and then dropped to 0 within a single epoch (by this I mean that, for example, in epoch 290 the IoU is 0.4 and in epoch 291 the IoU is 0).
- Using the "cleaned" dataset, both augmented and non-augmented, the IoU shows a similar behavior but with a lower value (it reaches 0.3 before dropping to 0).
The inference is at a "mid-point": it isn't perfect, but it isn't bad either (it segments the background as the first class, even though I am not including background among the classes).
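For context, the prediction step boils down to an argmax over the 4 output channels, roughly like this (simplified sketch); since there is no background channel, every pixel, including background, gets assigned to one of the 4 classes:

import torch

model_Unet.eval()
with torch.no_grad():
    logits = model_Unet(image.unsqueeze(0).to(device))    # image: [3, 1024, 1024] -> logits: [1, 4, H, W]
    pred = torch.argmax(logits, dim=1).squeeze(0).cpu()   # [H, W], values 0..3 (class indices, no background id)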
I've been thinking about implementing the Dice loss from the SMP repository, but I'm not sure how to do it, or whether I should use the class weights with it. My rough idea so far is the sketch below.
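This is just my guess from looking at the SMP losses module; I'm not sure whether "multilabel" is the right mode for my one-hot float masks, and as far as I can tell DiceLoss does not accept per-class weights:

import segmentation_models_pytorch as smp

dice_loss = smp.losses.DiceLoss(
    mode="multilabel",   # masks are one-hot [B, 4, H, W]; "multiclass" would expect index masks [B, H, W]
    from_logits=True,    # the model outputs raw logits
)

# inside the training loop, replacing (or combined with) the CrossEntropyLoss:
# loss = dice_loss(outputs, masks)   # outputs: [B, 4, H, W] logits, masks: [B, 4, H, W] float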
- Do you have any tips to fix and improve the IoU and the inference?