Hi there, I was wondering if somebody could help me with semantic segmentation.
I am using the segmentation_models_pytorch library to train a Unet on the VOC2012 dataset.
I have not trained semantic segmentation models before, so I am not sure which loss function to use or what format my data needs to be in to go into that loss function.
So far:
The input to my network is a bunch of images in the form:
[B, C, H, W]
This is currently [12, 3, 512, 512]
The output from my network is in the form:
[B, NUMCLASSES, H, W]
This is currently [12, 20, 512, 512]
I am loading in my normal images and segmentation truths using PIL.
So my input images have 3 channels, and my segmentation images have 1 channel (in PIL’s P mode).
So my input batch goes into the network, and the network outputs a tensor of shape [12, 20, 512, 512]
My ground truth images are in the shape of [12, 1, 512, 512]
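To make the shapes concrete, here is a minimal sketch using placeholder tensors in place of my real data. The variable names (images, logits, masks) are made up for illustration and are not from my actual code, and I am assuming the mask values are integer class indices.

```python
import torch

# Sizes taken from the shapes described above.
B, C, H, W = 12, 3, 512, 512
NUM_CLASSES = 20

# Placeholder tensors standing in for my real batches (hypothetical names).
images = torch.zeros(B, C, H, W)                     # network input
logits = torch.zeros(B, NUM_CLASSES, H, W)           # network output
masks = torch.zeros(B, 1, H, W, dtype=torch.long)    # ground truth, one channel

print(images.shape, logits.shape, masks.shape)
```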
My question:
What loss function should I use, and what format / shapes should my data be in?
Is it possible to just pass my outputs into a loss function as they are and calculate a loss, or do I need to reshape them in any way?
Do I have to calculate an output prediction with output.argmax(1) for input into a loss function?
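For reference, this is roughly what I was thinking of trying. It is only a sketch under my own assumptions: that torch.nn.CrossEntropyLoss is a reasonable choice here, that it takes the raw [B, NUMCLASSES, H, W] output without an argmax, and that the target should be a [B, H, W] tensor of integer class indices. The tensors are random placeholders, with a smaller batch and image size than my real data so it runs quickly.

```python
import torch
import torch.nn as nn

# Smaller than my real [12, 20, 512, 512] batch, just for a quick check.
B, NUM_CLASSES, H, W = 2, 20, 64, 64

# Random placeholders for the network output and ground truth (hypothetical names).
logits = torch.randn(B, NUM_CLASSES, H, W)
masks = torch.randint(0, NUM_CLASSES, (B, 1, H, W))

criterion = nn.CrossEntropyLoss()

# My guess: drop the channel dimension so the target is [B, H, W] of class indices.
target = masks.squeeze(1).long()

# Raw outputs go in directly; no argmax is applied here.
loss = criterion(logits, target)
print(loss.shape, loss.item())
```

If this is the wrong approach, or the squeeze is unnecessary, please let me know.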
Thanks for your time.