How do I create inter-class dependencies?


I am currently trying to improve the segmentation of an encoder-decoder network used to segment railroads.
The goal is for the network to segment the trackbed and rails that the train is driving on (the ego track) with as high a recall and precision as possible.
In most scenes, which contain a single track with no merges or forks, the network works really well, but it performs poorly in scenes where the track merges or forks.

Example mask:
(Yellow = the rails the train drives on; red and green = neighboring rails)

My theory is to classify the rails and the trackbed separately and use the rail information to improve the trackbed segmentation.
This should lead to fewer false positives, because areas that are not enclosed by two rails would no longer be classified as trackbed.
Forks and merges also have a specific rail layout, as seen in the example image. If the ego track were to continue to the left, the red rail on the right inside the yellow area would be connected to the yellow rail on the right. Since it is not connected, the ego track does not fork to the left and continues straight ahead.
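To make the "enclosed by two rails" rule concrete, here is a minimal NumPy sketch that, given a binary rail mask, keeps only trackbed pixels that have a rail pixel somewhere to their left and somewhere to their right in the same image row. It assumes a forward-facing camera, so the rails run roughly top to bottom and each image row crosses them; the function names `enclosed_by_rails` and `constrain_trackbed` and the 0.5 threshold are hypothetical. It only checks enclosure, not which rails connect to which, so on its own it does not resolve the fork/merge case, and applied at inference time it is the kind of post-processing step the question would prefer to avoid.

```python
import numpy as np

def enclosed_by_rails(rail_mask: np.ndarray) -> np.ndarray:
    """Boolean (H, W) mask of non-rail pixels with a rail pixel to their left AND right in the same row.

    Assumes rail_mask is a 2-D array where nonzero/True marks rail pixels.
    """
    rails = rail_mask.astype(bool)
    # seen_left[r, c] is True if any rail pixel exists at column <= c in row r
    seen_left = np.logical_or.accumulate(rails, axis=1)
    rail_left = np.zeros_like(rails)
    rail_left[:, 1:] = seen_left[:, :-1]  # shift so only columns strictly < c count
    # seen_right[r, c] is True if any rail pixel exists at column >= c in row r
    seen_right = np.logical_or.accumulate(rails[:, ::-1], axis=1)[:, ::-1]
    rail_right = np.zeros_like(rails)
    rail_right[:, :-1] = seen_right[:, 1:]  # shift so only columns strictly > c count
    return rail_left & rail_right & ~rails

def constrain_trackbed(trackbed_prob: np.ndarray, rail_prob: np.ndarray,
                       threshold: float = 0.5) -> np.ndarray:
    """Threshold the trackbed prediction, then drop pixels not enclosed by predicted rails."""
    enclosed = enclosed_by_rails(rail_prob > threshold)
    return (trackbed_prob > threshold) & enclosed
```

For example, with rails in columns 1 and 5 of a 7-pixel-wide row, only columns 2 to 4 survive as trackbed, regardless of how confident the network was about the other columns.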

The only idea I’ve had is to use multi-label masks: add a class that combines the rails and the trackbed into one class on top of the existing classes, and then segment only the pixels that have a high probability both for the rail or trackbed class and for the combined class.
I have not tested this yet, and I'm not even sure it would actually work.
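The multi-label idea could be sketched as follows, assuming the network outputs three independent channels (rails, trackbed, and the combined rails-or-trackbed class) trained with a per-channel sigmoid and binary cross-entropy instead of a softmax, so that a pixel can be positive in several channels at once. The function names, channel order, and threshold are hypothetical, and NumPy stands in for the training framework so the logic is easy to check; this is an untested illustration of the proposal, not a known-good recipe.

```python
import numpy as np

def make_multilabel_target(rail_mask: np.ndarray, trackbed_mask: np.ndarray) -> np.ndarray:
    """Stack binary (H, W) masks into a (3, H, W) multi-label target.

    Channel 0: rails, channel 1: trackbed, channel 2: rails OR trackbed (combined class).
    """
    rails = rail_mask.astype(bool)
    trackbed = trackbed_mask.astype(bool)
    combined = rails | trackbed
    return np.stack([rails, trackbed, combined]).astype(np.float32)

def bce_with_logits(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean binary cross-entropy over all channels and pixels, in the numerically stable form."""
    loss = np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits)))
    return float(loss.mean())

def fuse_predictions(probs: np.ndarray, threshold: float = 0.5):
    """Keep a rail/trackbed pixel only if the combined channel also exceeds the threshold.

    probs: (3, H, W) sigmoid outputs in the channel order of make_multilabel_target.
    """
    rails_p, trackbed_p, combined_p = probs
    combined_ok = combined_p > threshold
    rails = (rails_p > threshold) & combined_ok
    trackbed = (trackbed_p > threshold) & combined_ok
    return rails, trackbed
```

In this form the dependency between classes is only enforced in `fuse_predictions`, at inference time; during training the three channels interact only through the shared encoder-decoder features.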

I can't find any information on how to make one class depend on another.
I could improve the segmentation with some post-processing, but that is only a last resort; I would much rather solve this within the network itself.

Does anybody have an idea how to approach this problem?