My model is as follows:
class Ensemble(nn.Module):
    """Three-branch ensemble whose fused features feed a shared decoder head.

    Each of the three inputs is passed through its own sub-model. The three
    results are concatenated along dim 1 and then sent through four ``Up``
    stages, a 1x1 convolution, the configured activation and finally a
    softmax over dim 1.

    Args:
        model0: Callable module applied to the first input (``x1``).
        model1: Callable module applied to the second input (``x2``).
        model2: Callable module applied to the third input (``x3``).
        n_classes: Number of output channels of the final 1x1 convolution.
        activation: Value forwarded to ``Activation`` (defined elsewhere in
            the project) to build the activation applied before the softmax.
    """

    def __init__(self, model0, model1, model2, n_classes: int = 0,
                 activation: str = 'relu'):
        super().__init__()
        self.model0 = model0
        self.model1 = model1
        self.model2 = model2
        # NOTE(review): the default of 0 makes ``out_final`` a convolution
        # with zero output channels; callers presumably always pass a
        # positive class count -- confirm.
        self.n_classes = n_classes
        self.activation = Activation(activation)
        # ``Up`` is defined elsewhere; presumably Up(in_channels, out_channels)
        # upsampling stages -- TODO confirm. The first stage is built for 3072
        # input channels, so the concatenated sub-model outputs presumably
        # total 3072 channels (e.g. 3 x 1024) -- verify against the encoders.
        self.up1 = Up(3072, 768)
        self.up2 = Up(768, 192)
        self.up3 = Up(192, 48)
        self.up4 = Up(48, 12)
        self.out_final = nn.Conv2d(12, n_classes, 1)

    def forward(self, x1, x2, x3) -> torch.Tensor:
        """Run the three sub-models, fuse their outputs and decode.

        Args:
            x1: Input for ``model0``.
            x2: Input for ``model1``.
            x3: Input for ``model2``.

        Returns:
            The softmax (over dim 1) of the activated output of the final
            1x1 convolution.
        """
        feat0 = self.model0(x1)
        feat1 = self.model1(x2)
        feat2 = self.model2(x3)
        # Single concatenation along dim 1; equivalent to concatenating the
        # first two results and then appending the third. (End of the encoder)
        fused = torch.cat((feat0, feat1, feat2), dim=1)

        out = self.up1(fused)
        out = self.up2(out)
        out = self.up3(out)
        out = self.up4(out)
        out = self.out_final(out)
        # NOTE(review): the activation runs on the logits right before the
        # softmax; with the default 'relu' this presumably clamps negative
        # logits to zero -- confirm this is intended. Also, if training uses
        # nn.CrossEntropyLoss, that loss expects raw logits, not probabilities.
        out = self.activation(out)
        return F.softmax(out, dim=1)
How do I define a Dataset that returns multiple image inputs together with their masks for the above ensemble model in a semantic segmentation task?