How to construct a conditional decoder for each input class?

Hello, everyone.

I have a question about how to construct a separate decoder for each class, and how to train this architecture.

Here is my desired network architecture (the first part of the image). The purpose of this conditional structure is to generate better predictions for each class.

The input consists of sequential data, which first passes through an FC layer.
Then, depending on the input class (each input belongs to one of three classes: A, B, or C), the output of the FC layer is fed into the corresponding decoder (if the input class is A, it goes through decoder A).
Should I make a separate nn.Module for the encoder and for each decoder?
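A minimal PyTorch sketch of one way to organize this, assuming each sample is already a flat feature vector and that classes A, B, C are encoded as the integers 0, 1, 2. The class names `Decoder` and `ConditionalModel` and all layer sizes are hypothetical. Each decoder is its own `nn.Module` instance, but all of them live inside one parent module through `nn.ModuleList`, so a single optimizer sees every parameter.

```python
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """A small decoder; one instance is created per class (hypothetical sizes)."""

    def __init__(self, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)


class ConditionalModel(nn.Module):
    """Shared FC layer, then each sample is routed to the decoder of its class."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_classes: int = 3):
        super().__init__()
        self.fc = nn.Linear(in_dim, hidden_dim)  # shared by all classes
        # nn.ModuleList registers each decoder's parameters with this module,
        # so model.parameters() and model.to(device) cover all of them.
        self.decoders = nn.ModuleList([Decoder(hidden_dim, out_dim) for _ in range(num_classes)])
        self.out_dim = out_dim

    def forward(self, x: torch.Tensor, cls: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_dim); cls: (batch,) integer class ids, e.g. A=0, B=1, C=2
        h = torch.relu(self.fc(x))
        out = h.new_zeros((x.size(0), self.out_dim))
        for k, decoder in enumerate(self.decoders):
            mask = cls == k
            if mask.any():
                # Write each decoder's output back into the rows of its own
                # samples, so the output keeps the original batch order.
                out[mask] = decoder(h[mask])
        return out


model = ConditionalModel(in_dim=16, hidden_dim=32, out_dim=8)
x = torch.randn(5, 16)
cls = torch.tensor([0, 2, 1, 0, 2])
y = model(x, cls)  # row i was produced by decoder cls[i]
```

Writing the outputs back by mask instead of concatenating them means the targets can be used in their original order without any reshuffling.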

Currently, I am trying to build this architecture as a single nn.Module, as shown in the second part of the image. Each mini-batch contains samples from various classes.
After the FC layer, I split the batch by class.
Then each part goes through its own decoder, and I concatenate the outputs of all decoders as the final output.
Using this final output, I calculate the loss and back-propagate.
Is this a correct way to train them?
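The split-and-concatenate training step could look roughly like this. It is a sketch under stated assumptions: hypothetical layer sizes, MSE as the loss, SGD as the optimizer, and integer class ids 0, 1, 2 for A, B, C. The detail that matters is that concatenating the decoder outputs groups the samples by class, so the targets have to be gathered with the same indices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 16 input features, 32 hidden units, 8 outputs, 3 classes (A, B, C).
fc = nn.Linear(16, 32)
decoders = nn.ModuleList([nn.Linear(32, 8) for _ in range(3)])
params = list(fc.parameters()) + list(decoders.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = nn.MSELoss()


def train_step(x, cls, target):
    """Split the batch by class, decode each part, concatenate, then one backward pass."""
    h = torch.relu(fc(x))
    outputs, targets = [], []
    for k, decoder in enumerate(decoders):
        idx = (cls == k).nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            continue  # this class is absent from the mini-batch
        outputs.append(decoder(h[idx]))
        # Concatenation groups samples by class, so the targets must be
        # gathered in exactly the same order as the outputs.
        targets.append(target[idx])
    loss = loss_fn(torch.cat(outputs), torch.cat(targets))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


x = torch.randn(6, 16)
cls = torch.tensor([0, 1, 2, 2, 0, 1])
target = torch.randn(6, 8)
losses = [train_step(x, cls, target) for _ in range(20)]
```

With a mean-reduced loss such as `nn.MSELoss()`, the order of the rows does not change the loss value, as long as each output row is compared with its own target.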

I am wondering whether the total loss is back-propagated through every decoder, which would mean that decoder A could be updated by the prediction loss of decoder B.
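One way to check this empirically is a small experiment like the following (hypothetical layer sizes, only classes A and B). It back-propagates only the loss of the class-B samples and then inspects which parameters received a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(4, 8)                             # shared layer
dec_a, dec_b = nn.Linear(8, 2), nn.Linear(8, 2)  # decoders for classes A and B

x = torch.randn(6, 4)
cls = torch.tensor([0, 0, 0, 1, 1, 1])           # first 3 samples are A, last 3 are B
target = torch.randn(6, 2)

h = torch.relu(fc(x))
out_a = dec_a(h[cls == 0])
out_b = dec_b(h[cls == 1])

# Back-propagate only the loss of the class-B samples.
loss_b = nn.functional.mse_loss(out_b, target[cls == 1])
loss_b.backward()

# Decoder A's parameters were not used to compute loss_b, so they get no gradient.
print(dec_a.weight.grad)  # None
# Decoder B and the shared FC layer do receive gradients.
print(dec_b.weight.grad is not None, fc.weight.grad is not None)  # True True
```

Since decoder A does not take part in computing the class-B outputs, the class-B loss contributes nothing to decoder A's gradient, while the shared FC layer receives gradients from the samples of every class. With one combined loss, each decoder is therefore updated only by its own class's samples, although a mean reduction divides each term by the total batch size.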

This approach seems correct for training.
However, how would it work for inference (or for validation and testing), when no class labels are available?
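For inference without class labels, one common option (not part of the architecture described above) is to add a classification head on the shared features, train it jointly with cross-entropy, and route each test sample to the decoder of its predicted class. A rough sketch, where the class name `RoutedModel`, the added head, and all sizes are hypothetical assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RoutedModel(nn.Module):
    """Shared FC, a hypothetical class-prediction head, and one decoder per class."""

    def __init__(self, in_dim=16, hidden_dim=32, out_dim=8, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)  # assumption: added head
        self.decoders = nn.ModuleList([nn.Linear(hidden_dim, out_dim) for _ in range(num_classes)])
        self.out_dim = out_dim

    def forward(self, x, cls=None):
        h = torch.relu(self.fc(x))
        logits = self.classifier(h)
        if cls is None:
            # No labels (validation/test): route by the predicted class instead.
            cls = logits.argmax(dim=1)
        out = h.new_zeros((x.size(0), self.out_dim))
        for k, decoder in enumerate(self.decoders):
            mask = cls == k
            if mask.any():
                out[mask] = decoder(h[mask])
        return out, logits


model = RoutedModel()
x, cls, target = torch.randn(6, 16), torch.tensor([0, 1, 2, 2, 0, 1]), torch.randn(6, 8)

# Training: route by the true class and also teach the head to predict it.
out, logits = model(x, cls)
loss = F.mse_loss(out, target) + F.cross_entropy(logits, cls)
loss.backward()

# Inference: no class labels are passed in.
model.eval()
with torch.no_grad():
    pred_out, pred_logits = model(x)
```

A misclassified sample is sent to the wrong decoder, so the routing is only as good as the classifier. A soft alternative is to run every decoder and weight their outputs by the softmax of `logits`, which keeps the routing differentiable.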