I am trying to build a semantic segmentation model, but I am having a lot of trouble getting it to learn; so far it has not produced any useful results. To debug this, we decided to check whether the model can learn a single example: we feed one image into the model, run backprop, and apply an update, repeating this with the goal of overfitting to that one image. Even this fails, and I am not sure where we are going wrong or why. I am hoping to get some intuition on what the model might be doing so I can figure out next steps.
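The single-example check described above can be written as a small standalone loop. This is only a sketch under stated assumptions: it uses `nn.CrossEntropyLoss` on raw logits with integer targets and the Adam optimizer, not the SGD/BCELoss setup shown further down, and the tiny network, the random image, the random labels, and the helper name `overfit_single_batch` are all hypothetical placeholders rather than the actual U-Net or NYUv2 sample.

```python
import torch
import torch.nn as nn

def overfit_single_batch(model, x, y, steps=200, lr=1e-2):
    """Repeatedly fit one fixed batch and return the loss at every step.

    x: (N, in_channels, H, W) float images
    y: (N, H, W) integer class indices
    """
    loss_func = nn.CrossEntropyLoss()  # expects raw logits, no softmax
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_func(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
num_classes = 13
# Hypothetical tiny stand-in network (not the U-Net from this post);
# its last layer outputs raw logits with one channel per class.
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, num_classes, 1),
)
x = torch.randn(1, 3, 16, 16)                       # one fake image
y = torch.randint(0, num_classes, (1, 16, 16))      # one fake label map

losses = overfit_single_batch(model, x, y)
with torch.no_grad():
    pixel_acc = (model(x).argmax(dim=1) == y).float().mean().item()
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}, pixel accuracy {pixel_acc:.2f}")
```

If a loop like this fits a toy model but the same loop with the real U-Net and real sample does not, that would point toward the model or its output/loss setup rather than the optimization loop itself.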
The model is a simple U-Net with skip connections, convolution layers, and max-pooling layers. The data is NYUv2, segmented into 13 classes, so the network's last layer outputs an (N, 13, 512, 512) tensor in which each channel corresponds to one class. To produce the segmentation, we take an argmax across dimension 1. The problem is that as training proceeds, the network gradually starts labeling every pixel as 0 or 1, which are the background and wall classes. I would expect that, when training on only a single example, even a simple U-Net could produce reasonable results. We have tried cross-entropy loss, binary cross-entropy loss, L1 loss, mean squared error loss, and Dice loss, all with the same result.
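The argmax step and the "everything becomes 0 or 1" symptom can both be checked numerically. A minimal sketch, assuming random logits in place of the real network output and a 64x64 size instead of 512x512; the helper `class_histogram` is hypothetical and not part of the original code. It confirms the shape after argmax and counts how many pixels land in each class.

```python
import torch

def class_histogram(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Number of pixels assigned to each class in an integer label map."""
    return torch.bincount(labels.flatten().long(), minlength=num_classes)

num_classes = 13
torch.manual_seed(0)
# Random logits stand in for the network output, shape (N, C, H, W).
logits = torch.randn(2, num_classes, 64, 64)

# argmax over the class dimension gives one label per pixel, shape (N, H, W).
pred = torch.argmax(logits, dim=1)
print(pred.shape)  # torch.Size([2, 64, 64])

# Per-class pixel counts; a collapse onto the dominant classes shows up
# as nearly all of the mass in bins 0 and 1.
pred_counts = class_histogram(pred, num_classes)
print(pred_counts.tolist())
```

Running the same histogram on the ground-truth label map of the training image would show how much of the image the background and wall classes actually occupy, which helps tell class imbalance apart from other causes.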
Below is some of the code we used for training, along with the model output:
```python
import torch


def convert_to_one_hot(segmentation_images, num_classes):
    """
    Convert a segmentation image to one-hot encoding
    @params
    segmentation_images: (batch_size, height, width) or (batch_size, 1, height, width)
    num_classes: number of classes
    @return
    one_hot_images: (batch_size, num_classes, height, width)
    """
    if len(segmentation_images.shape) == 4:
        segmentation_images = segmentation_images.squeeze(1)
    one_hot = torch.nn.functional.one_hot(segmentation_images.long(), num_classes=num_classes)
    # one_hot puts the class dimension last: (B, H, W, C) -> (B, C, H, W)
    one_hot = torch.permute(one_hot, [0, 3, 1, 2]).float()
    return one_hot


def convert_to_segmentation(one_hot_images):
    """
    Convert a one-hot encoded image to a segmentation image
    @params
    one_hot_images: (batch_size, num_classes, height, width)
    @return
    segmentation_images: (batch_size, height, width)
    """
    # argmax over the class dimension removes dim 1
    return torch.argmax(one_hot_images, dim=1)
```
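One way to rule out the label pipeline is a round-trip check on the two helpers: one-hot encoding a label map and then taking the argmax should give back the original labels. A minimal sketch on random labels, with the helpers restated using the `(segmentation_images, num_classes)` signature so the block runs on its own. Running the same check on the real NYUv2 label tensor, and printing its unique values, would also show whether every label lies in [0, 13), since `torch.nn.functional.one_hot` raises an error for values at or above `num_classes`.

```python
import torch
import torch.nn.functional as F

def convert_to_one_hot(segmentation_images, num_classes):
    # (N, H, W) or (N, 1, H, W) integer labels -> (N, C, H, W) float one-hot
    if segmentation_images.dim() == 4:
        segmentation_images = segmentation_images.squeeze(1)
    one_hot = F.one_hot(segmentation_images.long(), num_classes=num_classes)
    return one_hot.permute(0, 3, 1, 2).float()

def convert_to_segmentation(one_hot_images):
    # (N, C, H, W) -> (N, H, W) class indices
    return torch.argmax(one_hot_images, dim=1)

torch.manual_seed(0)
labels = torch.randint(0, 13, (2, 1, 8, 8))   # fake labels in [0, 13)
print(labels.unique().tolist())               # check the label range

one_hot = convert_to_one_hot(labels, 13)
recovered = convert_to_segmentation(one_hot)

print(one_hot.shape)                              # torch.Size([2, 13, 8, 8])
print(torch.equal(recovered, labels.squeeze(1)))  # True
```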
```python
def train_Unet(model, dataloader, optimizer=None, loss_func=None, num_epochs=20, show_result_every=5):
    total_losses = []
    for epoch in range(num_epochs):
        epoch_losses = []
        for data in dataloader:
            x = data[0].to(device=device).float()
            y = data[1].to(device=device).float()
            # (N, H, W) labels -> (N, 13, H, W) one-hot targets
            y = utils.convert_to_one_hot(y, 13)
            optimizer.zero_grad()
            output = model(x)
            loss = loss_func(output, y)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
            total_losses.append(loss.item())
        epoch_losses = torch.tensor(epoch_losses)
        print(f"Epoch {epoch} loss: {epoch_losses.mean().item()}")
    return total_losses


# Grab one batch from the training loader and reuse it every epoch.
# model, train_dl, device and utils are defined elsewhere in our code.
data = next(iter(train_dl))
dataloader = [data]
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
loss_func = torch.nn.BCELoss().to(device=device)
train_Unet(model, dataloader, optimizer, loss_func=loss_func, num_epochs=1000)
```
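Because the loss function is one of the things already varied, it may help to confirm what each PyTorch loss expects as input. A minimal sketch on random tensors (shapes chosen only for illustration). Note that `nn.BCELoss` expects probabilities in [0, 1], so whether it fits the code above depends on whether the U-Net already ends in a sigmoid, which the post does not show.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 13, 16, 16
logits = torch.randn(N, C, H, W)          # raw network output, no activation
labels = torch.randint(0, C, (N, H, W))   # integer class index per pixel
one_hot = nn.functional.one_hot(labels, C).permute(0, 3, 1, 2).float()

# nn.BCELoss expects probabilities in [0, 1], so logits need a sigmoid first;
# nn.BCEWithLogitsLoss applies the sigmoid internally (more stably).
bce = nn.BCELoss()(torch.sigmoid(logits), one_hot)
bce_logits = nn.BCEWithLogitsLoss()(logits, one_hot)
print(torch.allclose(bce, bce_logits, atol=1e-6))  # True

# nn.CrossEntropyLoss takes raw logits of shape (N, C, H, W) and integer
# targets of shape (N, H, W); it applies log-softmax over dim 1 itself.
ce = nn.CrossEntropyLoss()(logits, labels)
manual = -torch.log_softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).mean()
print(torch.allclose(ce, manual, atol=1e-6))  # True
```

With `nn.CrossEntropyLoss`, the one-hot conversion is not needed at all: the integer label map can be passed directly as the target.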
Actual output:
(I can't upload the input or the expected segmentation because the forum won't let me, but this output contains only 2 colors, while there should be far more.)
I honestly am not sure where to even look for why this is happening. I would expect that training on a single example would quickly overfit and reproduce exactly that example, but it doesn't. Any help with this is greatly appreciated.