ValueError: Expected input batch_size (128) to match target batch_size (32)

I’m getting the error ValueError: Expected input batch_size (128) to match target batch_size (32). for the following code:

class Custom(nn.Module):
    
    def __init__(self):
        super(ParallelNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 7)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 7)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 5)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 254, 5)
        self.bn4 = nn.BatchNorm2d(254)
        self.conv5 = nn.Conv2d(254, 512, 3)
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(512, 32, 3)
        self.bn6 = nn.BatchNorm2d(32)
        
        self.conv1x1 = nn.Conv2d(32, 32, 1)
        
        self.fc = nn.Linear(32 * 26 * 26, 10) 

    def forward(self, x):  
        
        p = torch.chunk(x, 2, 3)
        
        xT = torch.chunk(p[0], 2, 2)
        xB = torch.chunk(p[1], 2, 2)
        
        x1 = F.relu(self.bn1(self.conv1(xT[0])))
        x1 = F.relu(self.bn2(self.conv2(x1)))       
        x1 = F.relu(self.bn3(self.conv3(x1)))
        x1 = F.relu(self.bn4(self.conv4(x1)))       
        x1 = F.relu(self.bn5(self.conv5(x1)))
        x1 = F.relu(self.bn6(self.conv6(x1)))
        
               
        x2 = F.relu(self.bn1(self.conv1(xT[1])))
        x2 = F.relu(self.bn2(self.conv2(x2)))
        x2 = F.relu(self.bn3(self.conv3(x2)))
        x2 = F.relu(self.bn4(self.conv4(x2)))      
        x2 = F.relu(self.bn5(self.conv5(x2)))
        x2 = F.relu(self.bn6(self.conv6(x2)))
        
        
        x3 = F.relu(self.bn1(self.conv1(xB[0])))
        x3 = F.relu(self.bn2(self.conv2(x3)))        
        x3 = F.relu(self.bn3(self.conv3(x3)))
        x3 = F.relu(self.bn4(self.conv4(x3)))        
        x3 = F.relu(self.bn5(self.conv5(x3)))
        x3 = F.relu(self.bn6(self.conv6(x3)))
        
            
        x4 = F.relu(self.bn1(self.conv1(xB[1])))
        x4 = F.relu(self.bn2(self.conv2(x4)))        
        x4 = F.relu(self.bn3(self.conv3(x4)))
        x4 = F.relu(self.bn4(self.conv4(x4)))        
        x4 = F.relu(self.bn5(self.conv5(x4)))
        x4 = F.relu(self.bn6(self.conv6(x4)))
         
        out = torch.cat((x1, x2, x3, x4))
        
        out = F.relu(self.conv1x1(out))
        
        out = out.view(out.size(0), -1)
       
        out = F.softmax(self.fc(out), dim=1)
        
        return out
    
model = Custom()

optimizer = optim.Adam(model.parameters(), lr=0.003)

epochs = 5

for e in range(epochs):
    
    train_loss = 0
    accuracy = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        log_ps = model(images)
        
        loss = nn.CrossEntropyLoss()(log_ps, labels)
        equality = (labels.data == ps.max(dim=1)[1])
        accuracy += equality.type(torch.FloatTensor).mean()
        
        loss.backward()
        optimizer.step()
         
        train_loss += loss.item()
        
    else:
        print("Training loss:", {running_loss})

The shape of this part out = F.relu(self.conv1x1(out)) is torch.Size([8, 32, 26, 26])

Can anyone help me, please? Thanks

I’m currently not seeing, where the batch sizes 32 and 128 are coming from, but I guess the torch.cat call yields the error, since it concatenates the tensors in dim0 by default.
Try to use out = torch.cat((x1, x2, x3, x4), 1) instead.