Loss not converging - satellite image binary segmentation with TernausNet

Hey there,

This is my first post, so feel free to critique my style.

I am trying to do binary segmentation of roofs in satellite images. My dataset is small: I have only 20 images for training and 4 for validation. As a template I used a Google Colab notebook from albumentations, which does binary segmentation of animals on the Oxford-IIIT Pet Dataset.

My notebook can be found here. It uses the TernausNet library, which provides pretrained UNet models for semantic segmentation.
When I run training, the loss does not really converge. It sometimes goes towards positive or negative infinity. I already tried changing the learning rate, which helped a bit but not enough.

The following are the major blocks of my notebook.


train_transform = A.Compose(
    [
        A.Resize(256, 256),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
train_dataset = RoofDataset(train_images_filenames, images_directory, masks_directory, transform=train_transform,)

val_transform = A.Compose(
    [A.Resize(256, 256), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2()]
)
val_dataset = RoofDataset(val_images_filenames, images_directory, masks_directory, transform=val_transform,)


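def train(train_loader, model, criterion, optimizer, epoch, params):
    metric_monitor = MetricMonitor()
    model.train()  # enable training-mode behaviour (dropout, batch norm)
    stream = tqdm(train_loader)
    for i, (images, target) in enumerate(stream, start=1):
        images = images.to(params["device"], non_blocking=True)
        target = target.to(params["device"], non_blocking=True)
        # The model returns (N, 1, H, W) logits; drop the channel axis to match the mask
        output = model(images).squeeze(1)
        loss = criterion(output, target)
        metric_monitor.update("Loss", loss.item())
        # Standard optimisation step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        stream.set_description(
            "Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
        )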


def validate(val_loader, model, criterion, epoch, params):
    metric_monitor = MetricMonitor()
    model.eval()  # switch to inference-mode behaviour
    stream = tqdm(val_loader)
    # No gradients are needed for validation
    with torch.no_grad():
        for i, (images, target) in enumerate(stream, start=1):
            images = images.to(params["device"], non_blocking=True)
            target = target.to(params["device"], non_blocking=True)
            output = model(images).squeeze(1)
            loss = criterion(output, target)
            metric_monitor.update("Loss", loss.item())
            stream.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
            )

Create model:

def create_model(params):
    model = getattr(ternausnet.models, params["model"])(pretrained=True)
    model = model.to(params["device"])
    return model

Train and validate:

def train_and_validate(model, train_dataset, val_dataset, params):
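    train_loader = DataLoader(
        train_dataset, batch_size=params["batch_size"], shuffle=True, num_workers=params["num_workers"],
    )
    val_loader = DataLoader(
        val_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"],
    )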
    criterion = nn.BCEWithLogitsLoss().to(params["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    for epoch in range(1, params["epochs"] + 1):
        train(train_loader, model, criterion, optimizer, epoch, params)
        validate(val_loader, model, criterion, epoch, params)
    return model


params = {
    "model": "UNet11",
    "device": "cuda",
    "lr": 0.0000025,
    "batch_size": 16,
    "num_workers": 4,
    "epochs": 20,
}

Train and val loss:

Epoch: 1. Train.      Loss: 2.441: 100%|██████████| 2/2 [00:00<00:00,  2.20it/s]
Epoch: 1. Validation. Loss: 1.697: 100%|██████████| 1/1 [00:00<00:00,  2.66it/s]
Epoch: 2. Train.      Loss: 1.892: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s]
Epoch: 2. Validation. Loss: 1.639: 100%|██████████| 1/1 [00:00<00:00,  2.77it/s]
Epoch: 3. Train.      Loss: 1.989: 100%|██████████| 2/2 [00:00<00:00,  2.33it/s]
Epoch: 3. Validation. Loss: 1.580: 100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
Epoch: 4. Train.      Loss: 1.748: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s]
Epoch: 4. Validation. Loss: 1.522: 100%|██████████| 1/1 [00:00<00:00,  2.60it/s]
Epoch: 5. Train.      Loss: 1.840: 100%|██████████| 2/2 [00:00<00:00,  2.36it/s]
Epoch: 5. Validation. Loss: 1.464: 100%|██████████| 1/1 [00:00<00:00,  2.76it/s]
Epoch: 6. Train.      Loss: 1.578: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
Epoch: 6. Validation. Loss: 1.405: 100%|██████████| 1/1 [00:00<00:00,  2.58it/s]
Epoch: 7. Train.      Loss: 1.844: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s]
Epoch: 7. Validation. Loss: 1.347: 100%|██████████| 1/1 [00:00<00:00,  2.70it/s]
Epoch: 8. Train.      Loss: 1.780: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s]
Epoch: 8. Validation. Loss: 1.288: 100%|██████████| 1/1 [00:00<00:00,  2.77it/s]
Epoch: 9. Train.      Loss: 1.882: 100%|██████████| 2/2 [00:00<00:00,  2.33it/s]
Epoch: 9. Validation. Loss: 1.230: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Epoch: 10. Train.      Loss: 1.725: 100%|██████████| 2/2 [00:00<00:00,  2.39it/s]
Epoch: 10. Validation. Loss: 1.171: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Epoch: 11. Train.      Loss: 1.366: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
Epoch: 11. Validation. Loss: 1.113: 100%|██████████| 1/1 [00:00<00:00,  2.77it/s]
Epoch: 12. Train.      Loss: 1.295: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s]
Epoch: 12. Validation. Loss: 1.055: 100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
Epoch: 13. Train.      Loss: 1.367: 100%|██████████| 2/2 [00:00<00:00,  2.32it/s]
Epoch: 13. Validation. Loss: 0.996: 100%|██████████| 1/1 [00:00<00:00,  2.73it/s]
Epoch: 14. Train.      Loss: 1.376: 100%|██████████| 2/2 [00:00<00:00,  2.30it/s]
Epoch: 14. Validation. Loss: 0.937: 100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
Epoch: 15. Train.      Loss: 0.874: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s]
Epoch: 15. Validation. Loss: 0.878: 100%|██████████| 1/1 [00:00<00:00,  2.74it/s]
Epoch: 16. Train.      Loss: 1.098: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
Epoch: 16. Validation. Loss: 0.819: 100%|██████████| 1/1 [00:00<00:00,  2.77it/s]
Epoch: 17. Train.      Loss: 0.679: 100%|██████████| 2/2 [00:00<00:00,  2.30it/s]
Epoch: 17. Validation. Loss: 0.759: 100%|██████████| 1/1 [00:00<00:00,  2.63it/s]
Epoch: 18. Train.      Loss: 0.898: 100%|██████████| 2/2 [00:00<00:00,  2.23it/s]
Epoch: 18. Validation. Loss: 0.699: 100%|██████████| 1/1 [00:00<00:00,  2.63it/s]
Epoch: 19. Train.      Loss: 0.813: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
Epoch: 19. Validation. Loss: 0.639: 100%|██████████| 1/1 [00:00<00:00,  2.68it/s]
Epoch: 20. Train.      Loss: 0.625: 100%|██████████| 2/2 [00:00<00:00,  2.27it/s]
Epoch: 20. Validation. Loss: 0.580: 100%|██████████| 1/1 [00:00<00:00,  2.61it/s]
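
For reading the log above: the `MetricMonitor` class is not shown in the post. Assuming it behaves like the helper in the albumentations template, the printed "Loss" would be a running average over the batches seen so far in the epoch, not the loss of the last batch. The class below is a hypothetical minimal stand-in under that assumption, not the notebook's actual code.

```python
from collections import defaultdict


class MetricMonitor:
    """Hypothetical stand-in: keeps a running average per metric name."""

    def __init__(self, float_precision=3):
        self.float_precision = float_precision
        self.metrics = defaultdict(lambda: {"total": 0.0, "count": 0, "avg": 0.0})

    def update(self, metric_name, val):
        # Accumulate the value and refresh the running average
        metric = self.metrics[metric_name]
        metric["total"] += val
        metric["count"] += 1
        metric["avg"] = metric["total"] / metric["count"]

    def __str__(self):
        # Render as "Name: 1.234 | Other: 0.567"
        return " | ".join(
            "{name}: {avg:.{p}f}".format(name=name, avg=m["avg"], p=self.float_precision)
            for name, m in self.metrics.items()
        )


monitor = MetricMonitor()
monitor.update("Loss", 2.0)
monitor.update("Loss", 3.0)
summary = str(monitor)  # average of the two batch losses
print(summary)
```

With only two training batches per epoch, an average like this can jump noticeably from epoch to epoch, which may explain part of the noise in the train column.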

As you can see, the loss doesn’t get very low. If I run for more epochs, it goes towards infinity again.
After this run I get predictions like this, which at least are not completely random:

If I run it again from a worse starting point, the loss goes towards negative or positive infinity and the predicted masks are either all zeros or all ones.
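
One possible cause worth ruling out, stated as a guess because `RoofDataset` is not shown: binary cross-entropy on logits cannot be negative when every target lies in [0, 1], so a loss heading towards negative infinity suggests the mask tensors may contain other values, for example 255 for roof pixels. The sketch below uses plain Python and the standard per-pixel formula rather than calling PyTorch, and the `binarize` helper is hypothetical.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Per-pixel binary cross-entropy on a raw logit, in the numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def binarize(mask_value: float) -> float:
    """Hypothetical preprocessing step: map any non-zero mask value (e.g. 255) to 1.0."""
    return 1.0 if mask_value > 0 else 0.0


# With a valid target in [0, 1] the loss is never negative
valid_loss = bce_with_logits(5.0, 1.0)

# With a roof pixel stored as 255 the loss is unbounded in both directions:
# a confident positive logit is rewarded without limit ...
negative_loss = bce_with_logits(5.0, 255.0)
# ... and a negative logit is punished 255 times harder than intended
positive_loss = bce_with_logits(-5.0, 255.0)

# After binarizing the mask value the loss is well behaved again
fixed_loss = bce_with_logits(5.0, binarize(255.0))

print(valid_loss, negative_loss, positive_loss, fixed_loss)
```

If this turns out to be the issue, checking `target.min()` and `target.max()` on one batch and converting the mask to float values of 0.0 and 1.0 inside the dataset class would be the place to start.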

Any help is appreciated :)