Constrain segmentation with additional loss of a pretrained model

Hello PyTorch Community,

I am looking for help/ideas on a training setup I want to try.

The idea is to train a VNET for a segmentation task and use its output for two tasks:

a) calculate the main loss (e.g. Dice, or any other segmentation loss)
b) use it as the input to a pretrained autoencoder (AE) and calculate an additional loss

loss = mainLoss + additionalLoss
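A minimal sketch of the main loss and the combination, assuming a soft Dice loss on softmax probabilities and one-hot targets of shape (N, C, D, H, W). The `weight` argument is a hypothetical knob not in the original formula; with `weight=1.0` it reduces to loss = mainLoss + additionalLoss.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice loss; probs and target are (N, C, D, H, W) with values in [0, 1]."""
    spatial_dims = tuple(range(2, probs.dim()))  # sum over D, H, W
    intersection = (probs * target).sum(dim=spatial_dims)
    denominator = probs.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    dice = (2 * intersection + eps) / (denominator + eps)  # per sample and class
    return 1 - dice.mean()


def combined_loss(main_loss: torch.Tensor, additional_loss: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
    # Hypothetical weighting; weight=1.0 gives exactly mainLoss + additionalLoss
    return main_loss + weight * additional_loss


# Usage with random logits for a 2-class 3D volume
logits = torch.randn(1, 2, 8, 8, 8, requires_grad=True)
labels = torch.randint(0, 2, (1, 8, 8, 8))
target = F.one_hot(labels, num_classes=2).permute(0, 4, 1, 2, 3).float()
main = soft_dice_loss(torch.softmax(logits, dim=1), target)
```

In practice the two terms can differ by orders of magnitude, so a weight other than 1.0 may be needed to keep the additional loss from dominating.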

Pretrained model: an autoencoder trained to learn the most prominent properties of the structure to segment.

Model to train: VNET

The idea is to use an autoencoder pretrained on segmentations, so that the additional loss preserves information about the shape of the structure to segment.
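One common pitfall is freezing the AE the wrong way. A minimal sketch, assuming a small hypothetical `TinyAE` in place of the real pretrained model: the AE's weights are frozen and it is put in eval mode, but its forward pass is *not* wrapped in `torch.no_grad()`, because the gradient still has to flow through the AE back into the VNET output.

```python
import torch
import torch.nn as nn


class TinyAE(nn.Module):
    """Hypothetical stand-in for the pretrained shape autoencoder (2-class 3D masks)."""

    def __init__(self, in_channels: int = 2, latent_channels: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, latent_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.ConvTranspose3d(latent_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


def freeze(model: nn.Module) -> nn.Module:
    """Put the AE in eval mode and stop its weights from being updated."""
    model.eval()  # use running BatchNorm stats / disable Dropout, if present
    for p in model.parameters():
        p.requires_grad_(False)
    return model


ae = freeze(TinyAE())
# Stand-in for the VNET's softmax output; in practice this comes from the VNET.
pred = torch.rand(1, 2, 8, 8, 8, requires_grad=True)
# Placeholder additional loss; no torch.no_grad() so gradients reach pred.
ae(pred).pow(2).mean().backward()
print(pred.grad is not None, all(p.grad is None for p in ae.parameters()))  # True True
```

Two more things worth checking: pass only the VNET parameters to the optimizer, and if the AE is a submodule of a larger wrapper, calling `.train()` on the wrapper switches the AE back to training mode, so call `.eval()` on it again afterwards.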

Do you have any tips on what to be aware of when doing that? Let me know if you have any questions.

Thank you,

How did you pretrain the autoencoder?
Based on your explanation, it seems you've used the real images.

Would this approach work if you now feed the segmentation output of your VNET to this AE?
The image statistics, value ranges, etc. should be quite different, so I'm not sure whether it'll work out of the box or whether you would need to train both models end-to-end.
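If the AE was instead pretrained on one-hot segmentation masks (an assumption here), one way to reduce that mismatch is to feed it softmax probabilities rather than raw logits. A minimal sketch: softmax gives values in [0, 1] whose channels sum to 1, like a one-hot mask, and stays differentiable, whereas `argmax` followed by one-hot encoding would match exactly but cut the gradient.

```python
import torch
import torch.nn.functional as F


def logits_to_ae_input(logits: torch.Tensor) -> torch.Tensor:
    """Map VNET logits (N, C, D, H, W) to per-voxel class probabilities.

    Softmax keeps the values in [0, 1] with channels summing to 1, i.e. the same
    range and layout as a one-hot mask, while remaining differentiable.
    """
    return torch.softmax(logits, dim=1)


def labels_to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Map an integer label volume (N, D, H, W) to a one-hot mask (N, C, D, H, W)."""
    return F.one_hot(labels, num_classes).permute(0, 4, 1, 2, 3).float()


logits = torch.randn(2, 3, 4, 4, 4)
labels = torch.randint(0, 3, (2, 4, 4, 4))
probs = logits_to_ae_input(logits)
one_hot = labels_to_one_hot(labels, num_classes=3)
print(probs.shape == one_hot.shape)  # True
```

Even with matching ranges, soft probability maps are blurrier than the crisp masks the AE saw during pretraining, so the AE may still behave differently on them.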

It's an interesting approach, and it reminds me of Adversarial Learning for Semi-Supervised Semantic Segmentation, but with an AE instead of a discriminator.

Thanks for replying so quickly.

What I am trying to do is described in this paper: Anatomically Constrained Neural Networks (ACNNs): Application to Cardiac Image Enhancement and Segmentation

The autoencoder is pretrained on the segmentation masks used as input in the ACNN setup.

The idea of the paper is to use the predicted segmentation from the VNET (or whatever segmentation network you are using). The predicted segmentation and the ground-truth segmentation are each fed to an instance of the pretrained AE, and the latent representations of both are compared using the Euclidean distance.
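A minimal sketch of that latent comparison, assuming a small hypothetical frozen encoder in place of the pretrained AE encoder. Applying the same frozen encoder to both inputs is equivalent to two AE instances that share the pretrained weights; only the prediction branch needs gradients.

```python
import torch
import torch.nn as nn


def latent_shape_loss(encoder: nn.Module, pred_probs: torch.Tensor, gt_one_hot: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distance between AE codes of prediction and ground truth.

    The same frozen encoder is applied to both inputs, which is equivalent to two
    AE instances that share the pretrained weights.
    """
    z_pred = encoder(pred_probs)  # gradients flow back to the VNET
    with torch.no_grad():
        z_gt = encoder(gt_one_hot)  # the target code needs no gradient
    return (z_pred - z_gt).flatten(start_dim=1).pow(2).sum(dim=1).mean()


# Hypothetical frozen encoder standing in for the pretrained AE encoder
encoder = nn.Sequential(nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1), nn.ReLU()).eval()
for p in encoder.parameters():
    p.requires_grad_(False)

gt = torch.zeros(1, 2, 8, 8, 8)
gt[:, 0] = 1.0  # trivial one-hot mask: everything is background
pred = torch.rand(1, 2, 8, 8, 8).softmax(dim=1).requires_grad_()
loss = latent_shape_loss(encoder, pred, gt)
loss.backward()
```

The total loss would then be something like the Dice loss plus a weighting factor times `latent_shape_loss`, where the weighting factor is a hyperparameter you would need to tune.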

Does this make more sense now?

Thanks for your ideas/comments.