Segmentation Model failing to learn even on single overtrained example

I am currently attempting to create a machine learning model that does semantic segmentation, but I am having a lot of trouble getting the model to learn. The model has been struggling to produce any useful results, so we decided to debug the issue by having the model learn on only a single example. We are currently feeding one image into the model, running backprop, and applying an update, with the goal of overfitting the model to that single example. Even this is failing, and I am not sure where we are going wrong or why. I was hoping to get some intuition on what the model might be doing so I could figure out some next steps.

The current model is just a simple U-Net with skip connections, convolution layers, and max pooling layers. The data is NYUv2, which contains 13 segmented classes. Thus, our network outputs an (N, 13, 512, 512) tensor at the last layer, where each channel represents a class. To create the segmentation, we then apply an argmax across dimension 1. The problem is that as the network trains, over time it starts to segment everything as 0 or 1, which are the background and the walls. I would think that when training on only a single example, even a simple U-Net would eventually produce somewhat okay results. We have tried cross-entropy loss, binary cross-entropy loss, L1 loss, mean squared error loss, and Dice loss, all with the same result.
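Since several losses were tried, it may help to pin down which target format each PyTorch loss expects. The following is a minimal sketch, assuming the last layer returns raw, unnormalized scores (logits) with no softmax or sigmoid applied; the tiny tensor sizes are stand-ins for (N, 13, 512, 512). `torch.nn.CrossEntropyLoss` takes logits plus integer class indices, `torch.nn.BCEWithLogitsLoss` takes logits plus a float one-hot target of the same shape, and the plain `torch.nn.BCELoss` instead expects probabilities in [0, 1].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 13, 8, 8  # small stand-in for (N, 13, 512, 512)

logits = torch.randn(N, C, H, W)         # raw network output, no softmax/sigmoid applied
target = torch.randint(0, C, (N, H, W))  # integer class index per pixel, dtype long

# CrossEntropyLoss: logits (N, C, H, W) + class indices (N, H, W)
ce = nn.CrossEntropyLoss()(logits, target)

# BCEWithLogitsLoss: logits + float one-hot target with the same shape (N, C, H, W)
one_hot = F.one_hot(target, num_classes=C).permute(0, 3, 1, 2).float()
bce = nn.BCEWithLogitsLoss()(logits, one_hot)

# Prediction: argmax over the class dimension gives one class index per pixel, (N, H, W)
pred = logits.argmax(dim=1)
print(f"cross-entropy {ce.item():.3f}, BCE-with-logits {bce.item():.3f}, pred shape {tuple(pred.shape)}")
```

With `CrossEntropyLoss` the one-hot conversion is not needed at all; the label map can be passed directly as a long tensor.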

Below is some of the code that we used for training along with the model outputs:

```python
import torch
import torch.nn.functional as F


def convert_to_one_hot(segmentation_images, num_classes):
    """
    Convert a segmentation image to one hot encoding

    @params
    segmentation_images: (batch_size, height, width) or (batch_size, 1, height, width)
    num_classes: number of classes

    @return
    one_hot_images: (batch_size, num_classes, height, width)
    """
    # Drop the singleton channel dimension if present: (N, 1, H, W) -> (N, H, W)
    if segmentation_images.dim() == 4:
        segmentation_images = segmentation_images.squeeze(1)

    # one_hot needs integer class indices and returns (N, H, W, num_classes)
    one_hot = F.one_hot(segmentation_images.long(), num_classes=num_classes)
    # Move the class dimension to position 1: (N, num_classes, H, W)
    one_hot = one_hot.permute(0, 3, 1, 2).float()

    return one_hot


def convert_to_segmentation(one_hot_images):
    """
    Convert a one hot encoding image to segmentation image

    @params
    one_hot_images: (batch_size, num_classes, height, width)

    @return
    segmentation_images: (batch_size, height, width)
    """
    # argmax over the class dimension removes it, leaving one class index per pixel
    return torch.argmax(one_hot_images, dim=1)
```
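Because `.long()` silently truncates and `one_hot` fails on out-of-range values, it may be worth checking the raw labels before converting them. This is a minimal sketch; `check_labels` is a hypothetical helper, and the random label batch only stands in for `data[1]`. It checks that the labels are whole numbers in [0, num_classes - 1], prints which classes actually occur, and confirms that one-hot encoding followed by argmax recovers the original labels.

```python
import torch
import torch.nn.functional as F


def check_labels(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: validate a label batch and return it as (N, H, W) long."""
    if labels.dim() == 4:
        labels = labels.squeeze(1)
    # .long() truncates, so make sure float labels are whole numbers (e.g. not interpolated)
    if labels.is_floating_point() and not torch.equal(labels, labels.round()):
        raise ValueError("labels contain non-integer values")
    labels = labels.long()
    lo, hi = labels.min().item(), labels.max().item()
    if lo < 0 or hi >= num_classes:
        raise ValueError(f"label values span [{lo}, {hi}], expected [0, {num_classes - 1}]")
    return labels


labels = torch.randint(0, 13, (2, 1, 8, 8)).float()  # float (N, 1, H, W), like data[1]
idx = check_labels(labels, 13)
one_hot = F.one_hot(idx, num_classes=13).permute(0, 3, 1, 2).float()
print("classes present:", torch.unique(idx).tolist())
assert torch.equal(one_hot.argmax(dim=1), idx)  # the round trip recovers the labels
```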

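```python
import torch

# Assumed to be defined elsewhere, as in the original snippet: model, train_dl, utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_Unet(model, dataloader, optimizer=None, loss_func=None, num_epochs=20, show_result_every=5):

    total_losses = []

    for epoch in range(num_epochs):
        epoch_losses = []

        for data in dataloader:
            x = data[0].to(device=device).float()
            y = data[1].to(device=device).float()

            # (N, 1, H, W) or (N, H, W) labels -> (N, 13, H, W) one-hot targets
            y = utils.convert_to_one_hot(y, 13)

            optimizer.zero_grad()
            output = model(x)

            # BCELoss expects probabilities in [0, 1], so the model must end in a sigmoid;
            # if it returns raw logits, BCEWithLogitsLoss is the matching loss.
            loss = loss_func(output, y)
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            total_losses.append(loss.item())

        epoch_losses = torch.tensor(epoch_losses)

        print(f"Epoch {epoch} loss: {epoch_losses.mean().item()}")

    return total_losses


# Overfit on a single batch by reusing the same batch every epoch
data = next(iter(train_dl))
dataloader = [data]
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
loss_func = torch.nn.BCELoss().to(device=device)

train_Unet(model, dataloader, optimizer, loss_func=loss_func, num_epochs=1000)
```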
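One way to separate a pipeline bug from a model problem is to run the same single-example overfitting check on a toy setup where the answer is known. The following is a minimal sketch under stated assumptions: the random 16 × 16 image and labels are hypothetical stand-ins for one NYUv2 sample, the small fully convolutional net stands in for the U-Net, and Adam replaces the original SGD purely for this sketch. It uses `CrossEntropyLoss` on raw logits with integer targets and tracks the loss and the set of predicted classes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 13

# Hypothetical stand-in for one sample: random image, random per-pixel labels
x = torch.randn(1, 3, 16, 16)
y = torch.randint(0, num_classes, (1, 16, 16))  # class indices, dtype long, shape (N, H, W)

# Tiny fully convolutional net standing in for the U-Net; the last layer returns raw logits
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, num_classes, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_func = nn.CrossEntropyLoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    logits = model(x)  # (1, 13, 16, 16)
    loss = loss_func(logits, y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    pred = model(x).argmax(dim=1)  # (1, 16, 16)
pixel_acc = (pred == y).float().mean().item()
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, pixel accuracy {pixel_acc:.2%}")
print("predicted classes:", torch.unique(pred).tolist())
```

If a loop of this shape drives the loss down on toy data but the real pipeline still collapses to classes 0 and 1, the difference is likely in the data, the targets, or the loss/activation pairing rather than in the U-Net itself.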

Actual output:

[image: predicted segmentation]

(I can't upload the input or the expected output because the forum won't let me, but this output has only 2 colors, while there should be far more.)

I honestly am not sure where to even look for why this might be occurring. I would think that if I'm training on a single example, the model should quickly overfit and reproduce exactly that example, but it doesn't. Any help on this is greatly appreciated.

So one thing to keep in mind is that you're doing roughly 250,000 classifications here (512 × 512 = 262,144, one per pixel). What is the distribution of the classes in your examples? In my experience, an imbalance between your 13 classes can be a problem in that setting. I sometimes mention a section in our book that describes this.
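To check the class distribution and, if it is heavily skewed, counteract it, one common heuristic is to weight the loss by inverse pixel frequency. This is a minimal sketch, not necessarily the approach described in the book; `inverse_frequency_weights` is a hypothetical helper and the toy label map (mostly class 0, a strip of class 1, a small patch of class 5) is made up for illustration. The resulting tensor is passed as the `weight` argument of `torch.nn.CrossEntropyLoss`.

```python
import torch


def inverse_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: per-class weights proportional to 1 / pixel frequency."""
    counts = torch.bincount(labels.flatten().long(), minlength=num_classes).float()
    freq = counts / counts.sum()
    present = counts > 0
    # Classes absent from the labels get weight 0 instead of dividing by zero
    weights = torch.where(present, 1.0 / freq.clamp(min=1e-12), torch.zeros_like(freq))
    # Normalize so the weights of the classes that do occur average to 1
    return weights / weights[present].mean()


labels = torch.zeros(1, 64, 64, dtype=torch.long)  # toy label map, mostly class 0
labels[:, :, 48:] = 1                              # a strip of class 1
labels[:, 10:14, 10:14] = 5                        # a small patch of class 5
print("pixels per class:", torch.bincount(labels.flatten(), minlength=13).tolist())

weights = inverse_frequency_weights(labels, 13)
loss_func = torch.nn.CrossEntropyLoss(weight=weights)
```

Rare classes then contribute more per pixel, so predicting only the dominant classes is no longer the cheapest way to reduce the loss.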

Best regards

Thomas