How to have two Unets with a shared encoder?

Hello, I’m trying to have two Unet models that share an encoder and train them simultaneously, but I don’t know how to do so. I was wondering if I could get some help here. Thanks!

Define one encoder and two decoders. Then for each of the decoders you call the same encoder.

Thank you @JuanFMontesinos. But how can I handle the skip connections that are in the Unet?

Well, the same way you do it for the UNet…
You run the encoder, get the latent features plus their skip connections, and run the decoder; then, with the same latent features and skip connections, run the other decoder.

Of course, this is simpler or harder to do depending on how the UNet you are using is coded and modularized.
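A minimal sketch of what such modularization could look like, assuming a toy two-level UNet: the `TinyEncoder`/`TinyDecoder` classes, channel counts and input size are hypothetical and not taken from any particular UNet implementation. The key point is that the encoder returns its bottleneck together with a list of skip features, and each decoder consumes that same pair.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Hypothetical 2-level UNet encoder returning the bottleneck and the skip features."""

    def __init__(self, in_ch=1, base=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, base, 3, padding=1)
        self.conv2 = nn.Conv2d(base, base * 2, 3, padding=1)
        self.bottleneck = nn.Conv2d(base * 2, base * 4, 3, padding=1)

    def forward(self, x):
        s1 = F.relu(self.conv1(x))                              # full resolution
        s2 = F.relu(self.conv2(F.max_pool2d(s1, 2)))            # 1/2 resolution
        latent = F.relu(self.bottleneck(F.max_pool2d(s2, 2)))   # 1/4 resolution
        return latent, [s1, s2]


class TinyDecoder(nn.Module):
    """Hypothetical decoder that upsamples and concatenates the encoder's skip features."""

    def __init__(self, out_ch=1, base=8):
        super().__init__()
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.conv2 = nn.Conv2d(base * 4, base * 2, 3, padding=1)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.conv1 = nn.Conv2d(base * 2, base, 3, padding=1)
        self.head = nn.Conv2d(base, out_ch, 1)

    def forward(self, latent, skips):
        s1, s2 = skips
        x = self.up2(latent)                                # back to 1/2 resolution
        x = F.relu(self.conv2(torch.cat([x, s2], dim=1)))   # skip connection, level 2
        x = self.up1(x)                                     # back to full resolution
        x = F.relu(self.conv1(torch.cat([x, s1], dim=1)))   # skip connection, level 1
        return self.head(x)


encoder = TinyEncoder()
decoder1, decoder2 = TinyDecoder(), TinyDecoder()

x = torch.randn(2, 1, 32, 32)
latent, skips = encoder(x)      # run the shared encoder once
o1 = decoder1(latent, skips)    # both decoders reuse the same latent + skips
o2 = decoder2(latent, skips)
print(o1.shape, o2.shape)
```

Because the encoder hands back the skips explicitly, adding a second decoder only means calling it on the same `(latent, skips)` pair; a UNet whose forward pass keeps the skips internal would need to be split up like this first.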

Edit, to clarify:

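```python
import torch.nn as nn


class Unet2Dec(nn.Module):
    # Assumes encoder(x) returns (latent, skip) and each decoder takes (latent, skip)
    def __init__(self, encoder: nn.Module, decoder1: nn.Module, decoder2: nn.Module):
        super().__init__()
        self.encoder = encoder    # shared by both branches
        self.decoder1 = decoder1
        self.decoder2 = decoder2

    def forward(self, x):
        latent, skip = self.encoder(x)     # run the shared encoder once
        o1 = self.decoder1(latent, skip)   # both decoders reuse the same
        o2 = self.decoder2(latent, skip)   # latent features and skip connections
        return o1, o2
```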
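To train both decoders simultaneously, as asked in the original question, one option is a single optimizer over the encoder's and both decoders' parameters and a summed loss. This is a sketch under assumptions: the one-level `Encoder`/`Decoder` modules, the random stand-in data, MSE losses and Adam are all hypothetical choices, and the unweighted sum `loss1 + loss2` is itself a choice (you may want to weight the two terms).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    # Hypothetical 1-level encoder returning (latent, skip)
    def __init__(self, ch=8):
        super().__init__()
        self.conv = nn.Conv2d(1, ch, 3, padding=1)
        self.down = nn.Conv2d(ch, ch * 2, 3, stride=2, padding=1)

    def forward(self, x):
        skip = F.relu(self.conv(x))
        return F.relu(self.down(skip)), skip


class Decoder(nn.Module):
    # Hypothetical 1-level decoder: upsample, concatenate the skip, predict 1 channel
    def __init__(self, ch=8):
        super().__init__()
        self.up = nn.ConvTranspose2d(ch * 2, ch, 2, stride=2)
        self.out = nn.Conv2d(ch * 2, 1, 3, padding=1)

    def forward(self, latent, skip):
        return self.out(torch.cat([self.up(latent), skip], dim=1))


torch.manual_seed(0)
encoder, decoder1, decoder2 = Encoder(), Decoder(), Decoder()
# One optimizer over the shared encoder (listed once) plus both decoders
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder1.parameters(), *decoder2.parameters()], lr=1e-3
)

x = torch.randn(4, 1, 16, 16)
target1 = torch.randn(4, 1, 16, 16)  # hypothetical targets for decoder 1
target2 = torch.randn(4, 1, 16, 16)  # hypothetical targets for decoder 2

for step in range(3):
    optimizer.zero_grad()
    latent, skip = encoder(x)                 # encoder runs once per batch
    loss1 = F.mse_loss(decoder1(latent, skip), target1)
    loss2 = F.mse_loss(decoder2(latent, skip), target2)
    loss = loss1 + loss2
    loss.backward()   # autograd sums the gradients of both heads into the encoder
    optimizer.step()
```

Since the encoder is called once and its output feeds both decoders, a single `backward()` on the summed loss accumulates gradients from both branches into the shared encoder weights.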

Thanks @JuanFMontesinos. One more question. Can I train each of these decoders on a different dataset? I.e. decoder-1 gets trained on trainset-1 and decoder-2 gets trained on trainset-2.

Hmm, that is a strange question… the tricky part is not how to train the decoders but how to train the encoder.

If you pass a sample from dataset 1 (ds1) through encoder e1 and decoder d1 and then backpropagate, this will affect both e1 and d1.
If you pass a sample from dataset 2 (ds2) through e1 and d2 and then backpropagate, this will affect e1 and d2.

So the encoder will be affected by both datasets while each decoder will be affected by a single dataset.
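That alternating scheme can be sketched as follows. The one-level `Encoder`/`Decoder` modules, the random `TensorDataset`s standing in for trainset-1 and trainset-2, MSE loss, Adam, and the helper `train_step` are all hypothetical. Each batch from trainset-1 goes through the encoder and decoder1, each batch from trainset-2 through the encoder and decoder2, with one shared optimizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class Encoder(nn.Module):
    # Hypothetical 1-level encoder returning (latent, skip)
    def __init__(self, ch=8):
        super().__init__()
        self.conv = nn.Conv2d(1, ch, 3, padding=1)
        self.down = nn.Conv2d(ch, ch * 2, 3, stride=2, padding=1)

    def forward(self, x):
        skip = F.relu(self.conv(x))
        return F.relu(self.down(skip)), skip


class Decoder(nn.Module):
    # Hypothetical 1-level decoder: upsample, concatenate the skip, predict 1 channel
    def __init__(self, ch=8):
        super().__init__()
        self.up = nn.ConvTranspose2d(ch * 2, ch, 2, stride=2)
        self.out = nn.Conv2d(ch * 2, 1, 3, padding=1)

    def forward(self, latent, skip):
        return self.out(torch.cat([self.up(latent), skip], dim=1))


torch.manual_seed(0)
encoder, decoder1, decoder2 = Encoder(), Decoder(), Decoder()
params = [*encoder.parameters(), *decoder1.parameters(), *decoder2.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

# Hypothetical stand-ins for trainset-1 and trainset-2: (inputs, targets)
trainset1 = TensorDataset(torch.randn(8, 1, 16, 16), torch.randn(8, 1, 16, 16))
trainset2 = TensorDataset(torch.randn(8, 1, 16, 16), torch.randn(8, 1, 16, 16))
loader1 = DataLoader(trainset1, batch_size=4, shuffle=True)
loader2 = DataLoader(trainset2, batch_size=4, shuffle=True)


def train_step(x, y, decoder):
    # set_to_none=True leaves the idle decoder's .grad as None, so Adam skips it
    optimizer.zero_grad(set_to_none=True)
    latent, skip = encoder(x)
    loss = F.mse_loss(decoder(latent, skip), y)
    loss.backward()   # gradients reach the encoder and this decoder only
    optimizer.step()
    return loss.item()


for (x1, y1), (x2, y2) in zip(loader1, loader2):
    train_step(x1, y1, decoder1)  # trainset-1 updates encoder + decoder1
    train_step(x2, y2, decoder2)  # trainset-2 updates encoder + decoder2
```

`zero_grad(set_to_none=True)` matters with a stateful optimizer like Adam: if the idle decoder's gradients were zero tensors instead of `None`, Adam's running moments could still move its weights during the other dataset's step. Alternating batch by batch and stopping when the shorter loader runs out (`zip`) are further assumptions; other schedules are possible.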
