I want to ensemble MyModelA and MyModelB, but I get a runtime error:

`Expected 4-dimensional input for 4-dimensional weight [8, 3, 3, 3], but got 2-dimensional input of size [1, 25088] instead`

How can I fix this?

```
class MyModelA(nn.Module):
    """Three-stage CNN classifier for 3-channel images.

    Each stage is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2). The convs use
    stride 1 with padding that preserves spatial size, so each stage only
    halves height and width. After three stages an ``H x W`` image becomes a
    ``32 x H//8 x W//8`` feature map. That map is flattened and fed to a
    linear classifier.

    Args:
        num_classes: Number of output logits. Defaults to 2.
        image_size: Side length of the square input images the classifier is
            sized for. The default 224 gives ``32 * 28 * 28 == 25088`` input
            features, which is the original hard-coded ``Linear(25088, 2)``.
            Submodule names are unchanged, so existing checkpoints still load.

    Input:  float tensor of shape ``(N, 3, image_size, image_size)``. This is
            a 4-D image batch, not a flattened ``(N, 25088)`` vector.
    Output: logits of shape ``(N, num_classes)``.
    """

    def __init__(self, num_classes: int = 2, image_size: int = 224):
        super().__init__()
        if image_size < 8:
            raise ValueError(
                f"image_size must be at least 8 (three 2x pools), got {image_size}"
            )
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Three 2x max-pools (floor mode) shrink each side to image_size // 8.
        side = image_size // 8
        self.fc = nn.Linear(32 * side * side, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = torch.flatten(out, 1)
        return self.fc(out)
class MyModelB(nn.Module):
..
class MyEnsemble(nn.Module):
    """Late-fusion ensemble of two independently applied models.

    Each sub-model receives its own input. Their outputs are concatenated
    along dim 1 and, if a ``classifier`` is given, passed through it. Both
    outputs must therefore agree in every dimension except dim 1.

    Args:
        modelA: Model applied to ``x1``.
        modelB: Model applied to ``x2``.
        classifier: Optional module applied to the concatenated outputs, e.g.
            ``nn.Linear(4, 2)`` when each model emits 2 logits. When ``None``
            (the default), the concatenation is returned as-is.
    """

    def __init__(self, modelA, modelB, classifier=None):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB
        self.classifier = classifier

    def forward(self, x1, x2):
        """Run ``modelA`` on ``x1`` and ``modelB`` on ``x2``, then fuse the results.

        ``MyModelA.forward`` returns a single logits tensor, so its result is
        not tuple-unpacked. ``modelB`` gets its own input ``x2`` rather than
        modelA's logits.
        """
        out_a = self.modelA(x1)
        out_b = self.modelB(x2)
        fused = torch.cat((out_a, out_b), dim=1)
        if self.classifier is not None:
            fused = self.classifier(fused)
        return fused
# Create the two models, load their trained weights, and wrap them in the ensemble.
modelA = MyModelA()
modelB = MyModelB()
# map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only machine.
# load_state_dict then copies the tensors into the models' own parameters.
modelA.load_state_dict(torch.load('checkpoint1.pt', map_location='cpu'))
modelB.load_state_dict(torch.load('checkpoint2.pt', map_location='cpu'))
model = MyEnsemble(modelA, modelB)
# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()

# Conv2d layers need a 4-D image batch (N, C, H, W), not a flattened
# (N, 25088) vector; passing the latter caused the "Expected 4-dimensional
# input" error. MyModelA's fc layer is sized for 3x224x224 images:
# 224 / 2**3 = 28, and 32 * 28 * 28 == 25088.
x1 = torch.randn(1, 3, 224, 224)
# NOTE(review): MyModelB's definition is not shown; this assumes it also takes
# 3x224x224 images and returns (N, k) logits. TODO confirm.
x2 = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x1, x2)
```