U-Net transfer learning

Hi, over the last couple of days I have been experimenting with the U-Net model and doing transfer learning by feature extraction (freezing layers of the model).
I have successfully trained a U-Net model from scratch for binary segmentation. My dataset consisted of 1100 slices of size 512 x 512.
Now I am trying to apply transfer learning by feature extraction and see what the result of freezing different layers of the model is. For that I use a new dataset of 300 slices of size 200 x 200.
Moreover, I have split my dataset into training and validation sets with a 0.9 split ratio.
With whichever combination of layers I choose to freeze, I get an overestimation of the segmentation.
I am trying to interpret this behaviour of my model. At first I thought that while using feature extraction I also need to tweak the hyperparameters of my model, or that the bad results are due to the size of my new dataset.
It now seems that just applying the initially trained model to the new dataset, without any training, gives the best results (still not satisfactory for the application that I want).
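For reference, by feature extraction I mean freezing parts of the network along these lines. This is just a minimal sketch: MyUNet and the encoder attribute are placeholders for my actual implementation, and the checkpoint path is made up.

```python
import torch

model = MyUNet()                                             # placeholder for the actual U-Net class
model.load_state_dict(torch.load("unet_from_scratch.pth"))   # weights trained on the 512x512 data

# freeze the chosen layers (here: the whole contracting path / encoder)
for param in model.encoder.parameters():
    param.requires_grad = False

# only the still-trainable parameters go to the optimizer
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
```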

My query, or the reason I post this topic, is to ask for guidance or any input on this matter, and maybe how you would approach such a task.
All comments are welcome.

Can you elaborate on what you mean by “overestimation of segmentation”?
Do you mean it's biased towards one particular class or so?

Yes, correct. I work with microscopy images of a material with two phases. During prediction it overestimates one phase in comparison to the other.
I hope it is clearer now.

Also, I forgot to mention in my first message that I use IoU as the metric for tracking the progress of my model.
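In case it matters, the IoU I track is computed roughly like this (a minimal sketch for a single binary output channel):

```python
import torch

def binary_iou(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Mean IoU over a batch of binary masks; logits and target are (N, 1, H, W)."""
    pred = (torch.sigmoid(logits) > 0.5).float()
    intersection = (pred * target).sum(dim=(1, 2, 3))
    union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - intersection
    return ((intersection + eps) / (union + eps)).mean()
```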

Have you checked whether your dataset is biased towards any of the classes? Maybe one is more populated? One has much worse examples?


Good one. Yes, this can be one explanation of my problem. I just checked my dataset. The problem is that the phase that I want to segment evolves dynamically with time. Thus, when I choose my training dataset it is really hard not to be biased.
So what I am trying to do is segment a spatio-temporal dataset with U-Net. If the phase of interest in my material develops relatively stably with time I get good results, but when the phase develops more ‘randomly’ I get bad predictions. Maybe a more sophisticated model, like a combination of LSTM and U-Net, would help resolve that problem, but for now I would like to avoid making my life more complicated.

If you don't want to account for the time variable in the model, you could try to dynamically increase the penalty on false positives of the out-of-balance class (some ratio). That should discourage the model from bias, but you'd have to test it.

Not sure how I can do that. Do you mean using some criterion and stopping training my model when I get ‘overestimation’?
Now I will try to use a bigger dataset and try to make my phase of interest more uniformly distributed.
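
In case it helps with the question above: one common way to add a class-dependent penalty to a binary segmentation loss in PyTorch (not necessarily exactly what was meant, but a related rebalancing trick) is the pos_weight argument of BCEWithLogitsLoss. The ratio below is an assumption; you would compute it from the pixel counts of your own masks.

```python
import torch
import torch.nn as nn

# assumed ratio of background to foreground pixels, computed from the training masks
pos_weight = torch.tensor([4.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(8, 1, 200, 200)                     # dummy model output
target = torch.randint(0, 2, (8, 1, 200, 200)).float()   # dummy ground-truth masks
loss = criterion(logits, target)
```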

Well, I think U-Nets are more than enough. I don't know exactly what the data looks like, but if you want to try a real cannon, vision transformers are always there, although training a new one for this purpose might be very hard.
Can you say whether the data is analyzed as a sequence, or as singular instances?
If it's a sequence, then it is much harder, but if the data is just a snapshot image, you can always try to add some transformations to your images to increase the number of training samples. This is superficial, but it helps.

I will have a look at vision transformers.
I would say that my data are a sequence. I don't know exactly how you define sequence and instance datasets.
But I believe it is a sequence, as I have a phase (a ‘structure’, if you want) that develops dynamically with time.

What I meant was: is your base data a single photo, or several developing frames? Do you feed one frame/slice at a time, where the next one is not connected to the previous one in the process, or is the n+1 result based on n? I don't understand the training procedure you have used, that's why I ask. You could try training the model on randomly ordered frames.

Ok, thanks, now I understand.
So you suggest that I could use shuffle in my dataloader and see if that helps, correct?

Yes. Generally, if you expose only one frame at a time to the model, you should avoid using ordered data in training. The model should be trained to recognize one frame and not be influenced by other entries. During weight optimization, frames that are similar to each other in a sequence may influence the way the model learns features, and so on.
What I propose:

  • try shuffling the training data,
  • try augmenting the data (rotations, cropping, maybe slight color changes, and introduce some noise into some of the pictures); use all available data and merge augmented and non-augmented data together (see the sketch after this list),
  • try keeping the data well balanced,
  • from what I understand you have changed the last layers of the network; you don't have much training data, so keep the structure shallow,
  • try overfitting on a set of 2-5 frames, to make sure you have no additional problems in your architecture.
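A minimal sketch of the first two points (shuffling and joint image/mask augmentation), assuming the slices and masks are already available as single-channel tensors; the class and variable names are just placeholders:

```python
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

class SliceDataset(Dataset):
    """Hypothetical wrapper: `images` and `masks` are lists of (1, H, W) tensors."""
    def __init__(self, images, masks, augment=True):
        self.images, self.masks, self.augment = images, masks, augment

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, mask = self.images[idx], self.masks[idx]
        if self.augment:
            # the SAME random geometric transform must be applied to image and mask
            if random.random() < 0.5:
                img, mask = TF.hflip(img), TF.hflip(mask)
            angle = random.choice([0, 90, 180, 270])
            img, mask = TF.rotate(img, angle), TF.rotate(mask, angle)
            # mild noise only on the image, never on the mask
            img = img + 0.01 * torch.randn_like(img)
        return img, mask

# shuffle=True so the model never sees the frames in their temporal order
# loader = DataLoader(SliceDataset(images, masks), batch_size=4, shuffle=True)
```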

Ok, thanks, that is clear.

So again, it is more about data preparation rather than which model you train.

My U-Net consists of 4 convolutional blocks and 4 skip connections.

I have around 1000 images for training my first model from scratch, and I would like to use transfer learning for a new task where my dataset is 300 slices.

It seems that one of the limitations was that I had an unbalanced dataset.

Any advice on how to select a well-balanced dataset?

The limitation right now for generalizing my trained model is the lack of a workflow for selecting my training dataset. I just visually inspect my images and hope that they are balanced enough.
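One thing I could do beyond visual inspection is to quantify the balance directly from the ground-truth masks, something like this (the folder and file pattern are made up for illustration):

```python
import glob
import numpy as np
from PIL import Image

ratios = []
for path in glob.glob("masks/*.png"):          # assumed location of the binary masks
    mask = np.array(Image.open(path)) > 0      # foreground = phase of interest
    ratios.append(mask.mean())                 # fraction of foreground pixels per slice

ratios = np.array(ratios)
print(f"foreground fraction: mean={ratios.mean():.3f}, "
      f"min={ratios.min():.3f}, max={ratios.max():.3f}")
```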

You can try to upscale your images to 512x512, as in the initial training.
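For example (a minimal sketch with dummy tensors; nearest-neighbour interpolation for the mask keeps the labels binary):

```python
import torch
import torch.nn.functional as F

img = torch.rand(1, 1, 200, 200)                    # dummy 200x200 slice, (N, C, H, W)
mask = (torch.rand(1, 1, 200, 200) > 0.5).float()   # dummy binary mask

img_up = F.interpolate(img, size=(512, 512), mode="bilinear", align_corners=False)
mask_up = F.interpolate(mask, size=(512, 512), mode="nearest")
```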

That would be really nice, but I am constrained by the nature of my images. I work with different datasets, each with different image dimensions. I initially worked with the one that had the largest image size, as I wanted to benefit from it. But now I also need to segment the remaining datasets, which consist of images with smaller dimensions.

After preparing a new dataset with a more even and balanced representation of my classes, I got better results. So in my case the problem was, as mentioned above, the lack of balance in my training dataset.