Change the forward path with a certain probability

Hi,
Let’s say I have a Bi-autoencoder with two streams, each of which is a standard autoencoder (AE).

class BiAutoencoder(nn.Module):
    """Two independent fully connected autoencoders bundled in one module.

    Each stream maps an input with 28 * 28 features down to a 3-dimensional
    code and back to 28 * 28 values; the final ``Tanh`` bounds every output
    value to [-1, 1].  The two streams do not share any parameters.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order, which fixes both the order
        # of the state_dict entries and the order of random initialisation.
        self.encoder1 = self._build_encoder()
        self.decoder1 = self._build_decoder()
        self.encoder2 = self._build_encoder()
        self.decoder2 = self._build_decoder()

    @staticmethod
    def _build_encoder():
        """Return a fresh 784 -> 128 -> 64 -> 12 -> 3 encoder stack."""
        return nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3),
        )

    @staticmethod
    def _build_decoder():
        """Return a fresh 3 -> 12 -> 64 -> 128 -> 784 decoder stack ending in Tanh."""
        return nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x1, x2):
        """Reconstruct each input through its own encoder/decoder pair.

        Args:
            x1: Input for stream 1; its last dimension must be 28 * 28.
            x2: Input for stream 2; its last dimension must be 28 * 28.

        Returns:
            Tuple ``(x1_hat, x2_hat)`` holding the reconstruction of ``x1``
            from stream 1 and of ``x2`` from stream 2.
        """
        z1 = self.encoder1(x1)
        x1_hat = self.decoder1(z1)
        z2 = self.encoder2(x2)
        x2_hat = self.decoder2(z2)
        return x1_hat, x2_hat

Now, during training, I want the output of each encoder to be routed through the other stream’s decoder with a certain probability.
That is, with a 50% chance the forward pass runs normally, as above.
With the remaining 50% chance, the forward pass would look like this:

    def forward(self, x1, x2):
        """Reconstruct with the decoders swapped between the two streams.

        The code produced by ``encoder1`` is decoded by ``decoder2`` and the
        code produced by ``encoder2`` is decoded by ``decoder1``.

        Args:
            x1: Input passed to ``encoder1``.
            x2: Input passed to ``encoder2``.

        Returns:
            Tuple ``(x1_hat, x2_hat)`` where ``x1_hat`` is the output of
            ``decoder1`` and ``x2_hat`` is the output of ``decoder2``.
        """
        z1 = self.encoder1(x1)
        # Cross path: stream 1's code goes through stream 2's decoder.
        x2_hat = self.decoder2(z1)
        z2 = self.encoder2(x2)
        # Cross path: stream 2's code goes through stream 1's decoder.
        x1_hat = self.decoder1(z2)
        return x1_hat, x2_hat

Any suggestions? Thanks.

I think this should work (I haven’t tested it; note that each call runs only one of the two cross paths and returns a single reconstruction):

  import random
  def forward(self, x1, x2):
      """Run one randomly chosen cross path and return its reconstruction.

      A fair coin flip selects the path.  On ``True`` the result is
      ``decoder2(encoder1(x1))`` and ``x2`` is ignored; otherwise the result
      is ``decoder1(encoder2(x2))`` and ``x1`` is ignored.  Only a single
      value is returned per call, not a pair.
      """
      use_first_stream = random.choice([True, False])
      if not use_first_stream:
          # Encode x2 with stream 2, decode with stream 1's decoder.
          return self.decoder1(self.encoder2(x2))
      # Encode x1 with stream 1, decode with stream 2's decoder.
      return self.decoder2(self.encoder1(x1))