The top neural network outputs an image, which needs to be fed to the two neural networks below it; each of those in turn feeds its image to the two neural networks below it, and the resulting probabilities are then returned back up to the top. This implementation uses a separate JoinProbs class to compute the combined probability, but that computation needs to be done within the Top, Middle, and Base classes themselves. When I attempt to call the forward method of Middle from Top, it does not give the desired result.
class Top(nn.Module):
    """Top stage: encode the input image twice and return both encodings.

    The encoder is a 1x1, single-channel convolution followed by dropout
    with p=0.5, so every output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Dropout(0.5))
        # NOTE(review): ``lin`` is never used by ``forward``. It is kept so
        # the module's parameters and state_dict keys stay unchanged.
        self.lin = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> tuple:
        """Pass ``x`` through the encoder twice.

        Args:
            x: Single-channel image tensor, e.g. of shape ``(N, 1, H, W)``.

        Returns:
            A 2-tuple of tensors shaped like ``x``. The encoder is invoked
            once per element, so in training mode each element is produced
            by its own dropout pass; in eval mode dropout is the identity
            and both elements are equal.
        """
        return self.encoder(x), self.encoder(x)
# Build the top stage and push one random single-channel 28x28 image through it.
net = Top()
outp1, outp2 = net(torch.randn((1, 1, 28, 28)))
# Notebook-style echo: the shapes of the two encodings.
tuple(t.shape for t in (outp1, outp2))
class Middle(nn.Module):
    """Middle stage: encode the incoming image twice and return both encodings.

    The encoder is a 1x1, single-channel convolution followed by dropout
    with p=0.5, so every output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Dropout(0.5))
        # NOTE(review): ``lin`` is never used by ``forward``. It is kept so
        # the module's parameters and state_dict keys stay unchanged.
        self.lin = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> tuple:
        """Pass ``x`` through the encoder twice.

        Args:
            x: Single-channel image tensor, e.g. of shape ``(N, 1, H, W)``.

        Returns:
            A 2-tuple of tensors shaped like ``x``. The encoder is invoked
            once per element, so in training mode each element is produced
            by its own dropout pass; in eval mode dropout is the identity
            and both elements are equal.
        """
        return self.encoder(x), self.encoder(x)
# One shared Middle instance expands each top-stage output into two images,
# giving four images in total (outp1 is processed before outp2).
net2 = Middle()
outp11, outp12 = net2(outp1)
outp21, outp22 = net2(outp2)
# Notebook-style echo: the shapes of the four images.
y = map(lambda t: t.shape, (outp11, outp12, outp21, outp22))
list(y)
class Base(nn.Module):
    """Bottom stage: flatten each sample and map it to class scores.

    Args:
        in_features: Number of values per sample after flattening. Defaults
            to 784 (for example a 1x28x28 image).
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()
        self.lin = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten everything after the first dimension and apply ``lin``.

        Args:
            x: Tensor whose first dimension indexes samples and whose
                remaining dimensions multiply out to ``in_features``.

        Returns:
            Tensor of shape ``(x.size(0), num_classes)``. These are the raw
            outputs of the linear layer; no softmax is applied here.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed tensors, copying only when a view is impossible.
        flat = x.reshape(x.size(0), -1)
        return self.lin(flat)
# Score each of the four middle-stage images with one shared Base instance,
# in the order outp11, outp12, outp21, outp22.
net3 = Base()
prob11, prob12, prob21, prob22 = map(net3, (outp11, outp12, outp21, outp22))
# Notebook-style echo: the shapes of the four score tensors.
y = map(lambda t: t.shape, (prob11, prob12, prob21, prob22))
list(y)
class JoinProbs(nn.Module):
    """Merge several score tensors into one with a single linear layer.

    Args:
        num_inputs: How many tensors ``forward`` is expected to receive.
            Defaults to 2.
        num_classes: Size of the last dimension of every input and of the
            output. Defaults to 10.

    With the defaults the layer is ``nn.Linear(20, 10)``.
    """

    def __init__(self, num_inputs: int = 2, num_classes: int = 10):
        super().__init__()
        self.lin = nn.Linear(num_inputs * num_classes, num_classes)

    def forward(self, *x: torch.Tensor) -> torch.Tensor:
        """Concatenate the inputs along the last dimension and project them.

        Args:
            *x: Tensors to merge. Their last dimensions must add up to the
                linear layer's input size (``num_inputs * num_classes``).

        Returns:
            The linear layer's output, with ``num_classes`` values in the
            last dimension. No softmax or other normalisation is applied.
        """
        joined = torch.cat(x, dim=-1)
        return self.lin(joined)
# One shared JoinProbs instance merges each pair of base-stage scores...
net4 = JoinProbs()
a = net4(prob11, prob12)
b = net4(prob21, prob22)
# Notebook-style echo of the two intermediate results.
(a, b)
# ...and the same instance then merges those two results into the final output.
net4(a, b)