Passing an output image down from the top neural network, and probabilities from the bottom networks back up to the top

The top neural network outputs an image. That image must be fed to two neural networks below it, each of which feeds an image to two further networks below it. Those bottom networks then return probabilities back up to the top. This implementation uses a separate JoinProbs class to combine the probabilities, but I need that combination to happen within the Top, Base and Middle classes themselves. When I try to call the forward method of Middle from Top, it does not give the desired result.

class Top(nn.Module):
  """Encode an image batch twice and return both encodings as a tuple.

  The encoder is a 1x1 Conv2d followed by Dropout(0.5). It is invoked in two
  separate calls, so in training mode each result gets its own dropout mask.
  ``self.lin`` is created (it appears in the state dict) but not used in
  ``forward``.
  """

  def __init__(self):
    super().__init__()
    conv = nn.Conv2d(1, 1, 1)
    drop = nn.Dropout(0.5)
    self.encoder = nn.Sequential(conv, drop)
    self.lin = nn.Linear(10, 10)

  def forward(self, x):
    # One fresh encoder pass per output, evaluated left to right.
    return tuple(self.encoder(x) for _ in range(2))

net = Top()

# Two encodings of one random 1x28x28 image; Top calls its encoder (which
# contains Dropout) once per output, so in training mode they usually differ.
outp1, outp2 = net(torch.randn(1, 1, 28, 28))

(outp1.shape, outp2.shape)

class Middle(nn.Module):
  """Second-level node: encode its input twice and return both encodings.

  Structurally identical to the first ``Top``: a 1x1 Conv2d plus Dropout(0.5)
  encoder that is called once per returned tensor. ``self.lin`` is created
  but unused in ``forward``.
  """

  def __init__(self):
    super().__init__()
    layers = [nn.Conv2d(1, 1, 1), nn.Dropout(0.5)]
    self.encoder = nn.Sequential(*layers)
    self.lin = nn.Linear(10, 10)

  def forward(self, x):
    first = self.encoder(x)
    second = self.encoder(x)
    return first, second

net2 = Middle()

# The same net2 instance (shared weights) splits each Top output in two,
# giving four encodings in total.
outp11, outp12, outp21, outp22 = *net2(outp1), *net2(outp2)

# Build a list, not a one-shot map() iterator: the old `list(y)` exhausted
# the iterator, so displaying y a second time showed [].
y = [t.shape for t in (outp11, outp12, outp21, outp22)]

y

class Base(nn.Module):
  """Leaf scorer: flatten each sample and map its 784 features to 10 outputs.

  The result is the raw output of a Linear layer (no softmax), i.e. logits
  rather than probabilities.
  """

  def __init__(self):
    super().__init__()
    self.lin = nn.Linear(784, 10)

  def forward(self, x):
    # flatten(1) keeps dim 0 as the batch. Unlike view(x.size(0), -1), it
    # also works for non-contiguous inputs (copying when needed) and for an
    # empty batch, where view cannot infer the -1 dimension.
    return self.lin(x.flatten(1))

net3 = Base()

# One Base instance (shared weights) scores all four encodings; the results
# are raw Linear outputs, not softmax-normalized probabilities.
prob11, prob12, prob21, prob22 = net3(outp11), net3(outp12), net3(outp21), net3(outp22)

# List instead of a map() iterator so y can be inspected more than once.
y = [t.shape for t in (prob11, prob12, prob21, prob22)]

y

class JoinProbs(nn.Module):
  """Fuse several per-branch score tensors into one via a linear layer.

  The inputs are concatenated along the last dimension and projected by
  ``self.lin``. With the defaults this is ``Linear(20, 10)``: two 10-wide
  inputs in, one 10-wide output, exactly as before.

  No softmax is applied, so despite the name the output is unnormalized
  scores (logits), not probabilities.

  Args:
    num_inputs: how many tensors ``forward`` is expected to receive.
    num_features: last-dimension width of each input and of the output.
  """

  def __init__(self, num_inputs=2, num_features=10):
    super().__init__()
    self.lin = nn.Linear(num_inputs * num_features, num_features)

  def forward(self, *x):
    # Any split of the inputs works as long as their widths sum to
    # num_inputs * num_features (e.g. one pre-concatenated tensor).
    out = torch.cat(x, dim=-1)
    return self.lin(out)

net4 = JoinProbs()

# Fuse each pair of leaf outputs, then fuse the two fused results again,
# reusing the same net4 instance (and weights) at both levels.
a = net4(prob11, prob12)
b = net4(prob21, prob22)

(a, b)

net4(a, b)

I then modified it to take a direction argument, 'up' or 'down':

from IPython.core.debugger import set_trace

class Top(nn.Module):
  """Root node of the tree (``'down'``) or a small output head (``'up'``).

  ``direction='down'``: the input image is encoded (1x1 Conv2d + Dropout) and
  sent to two ``Bottom`` children. The encoder runs once per child, so in
  training mode each child sees an independently dropped-out image.

  ``direction='up'``: ``forward`` just applies ``self.lin`` to its input.
  ``Bottom`` uses a ``Top('up')`` instance as its own output head.

  Args:
    direction: ``'down'`` or ``'up'``; any other value raises ``ValueError``
      (previously it was accepted and ``forward`` silently returned None).
    join: keyword-only, only used when ``direction='down'``.
      - False (default): ``forward`` returns the two children's outputs as a
        tuple, as before.
      - True: the two outputs are concatenated on the last dimension and fused
        by ``self.join`` (``Linear(20, 10)``, like ``JoinProbs``), so
        ``forward`` returns a single tensor instead of a tuple.
  """

  def __init__(self, direction, *, join=False):
    super().__init__()
    if direction not in ('up', 'down'):
      raise ValueError(f"direction must be 'up' or 'down', got {direction!r}")
    self.encoder = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Dropout(0.5))
    self.lin = nn.Linear(10, 10)
    self.direction = direction
    if self.direction == 'down':
      self.childone = Bottom()
      self.childtwo = Bottom()
    # Each child emits 10 features, so the fused pair is 20 wide.
    self.join = nn.Linear(20, 10) if (self.direction == 'down' and join) else None

  def forward(self, x):
    if self.direction == 'up':
      return self.lin(x)
    # 'down': same call order as before (encode -> child one, encode -> child two).
    out_one = self.childone(self.encoder(x))
    out_two = self.childtwo(self.encoder(x))
    if self.join is None:
      return out_one, out_two
    return self.join(torch.cat((out_one, out_two), dim=-1))

class Bottom(nn.Module):
  """Leaf node: flatten each sample, apply Linear(784, 100) then Linear(100, 10),
  and pass the 10-wide result through a ``Top('up')`` head.

  NOTE(review): ``self.parent`` is a brand-new ``Top('up')`` owned by this
  leaf. It is not the ``Top`` instance that created this ``Bottom``, so its
  ``lin`` weights are not shared with the enclosing network. Each ``Bottom``
  gets its own copy.
  """

  def __init__(self):
    super().__init__()
    self.lin_one = nn.Linear(784, 100)
    self.lin_two = nn.Linear(100, 10)
    self.parent = Top('up')

  def forward(self, x):
    # flatten(1) keeps the batch dim. Unlike view(), it also accepts
    # non-contiguous inputs and an empty batch. Each sample must flatten to
    # 784 values to match lin_one.
    flat = x.flatten(1)
    return self.parent(self.lin_two(self.lin_one(flat)))

But this also has a problem: Top receives two outputs (one from each Bottom), so in the end I get a tuple instead of a single output.