Dataparallel size does not match with target

I am using DataParallel to train my model on 2 GPUs. However, when calculating loss, it reports:

RuntimeError: The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0

Here is my model.py

import torch
import torch.nn as nn
import math
import torch.nn.init as init
import os

class _ResBLockDB(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(_ResBLockDB, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, stride, 1, bias=True)
        ])
        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                j = i.kernel_size[0] * i.kernel_size[1] * i.out_channels
                i.weight.data.normal_(0, math.sqrt(2 / j))
                if i.bias is not None:
                    i.bias.data.zero_()

    def forward(self, x):
        out = self.layers(x)
        residual = x
        out = torch.add(residual, out)
        return out

class _ResBlockSR(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(_ResBlockSR, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, stride, 1, bias=True)
        ])
        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                j = i.kernel_size[0] * i.kernel_size[1] * i.out_channels
                i.weight.data.normal_(0, math.sqrt(2 / j))
                if i.bias is not None:
                    i.bias.data.zero_()

    def forward(self, x):
        out = self.layers(x)
        residual = x
        out = torch.add(residual, out)
        return out

class _DeblurringMoudle(nn.Module):
    def __init__(self):
        super(_DeblurringMoudle, self).__init__()
        self.conv1     = nn.Conv2d(3, 64, (7, 7), 1, padding=3)
        self.relu      = nn.LeakyReLU(0.2, inplace=True)
        layers = []
        for i in range(0, 6):
            layers.append(_ResBLockDB(64, 64))
        self.resBlock1 =  nn.ModuleList(layers)
        # self.resBlock1 = self._makelayers(64, 64, 6)
        self.conv2 = nn.ModuleList([
            nn.Conv2d(64, 128, (3, 3), 2, 1),
            nn.ReLU(inplace=True)
        ])
        layers = []
        for i in range(0, 6):
            layers.append(_ResBLockDB(128, 128))
        self.resBlock2 = nn.ModuleList(layers)
        # self.resBlock2 = self._makelayers(128, 128, 6)
        self.conv3 = nn.ModuleList([
            nn.Conv2d(128, 256, (3, 3), 2, 1),
            nn.ReLU(inplace=True)
        ])
        layers = []
        for i in range(0, 6):
            layers.append(_ResBLockDB(256, 256))
        self.resBlock3 = nn.ModuleList(layers)
        # self.resBlock3 = self._makelayers(256, 256, 6)
        self.deconv1 = nn.ModuleList([
            nn.ConvTranspose2d(256, 128, (4, 4), 2, padding=1),
            nn.ReLU(inplace=True)
        ])
        self.deconv2 = nn.ModuleList([
            nn.ConvTranspose2d(128, 64, (4, 4), 2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, (7, 7), 1, padding=3)
        ])
        self.convout = nn.ModuleList([
            nn.Conv2d(64, 64, (3, 3), 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, (3, 3), 1, 1)
        ])
        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                j = i.kernel_size[0] * i.kernel_size[1] * i.out_channels
                i.weight.data.normal_(0, math.sqrt(2 / j))
                if i.bias is not None:
                    i.bias.data.zero_()

    # def _makelayers(self, inchannel, outchannel, block_num, stride=1):
    #     layers = []
    #     for i in range(0, block_num):
    #         layers.append(_ResBLockDB(inchannel, outchannel))
    #     return nn.Sequential(*layers)

    def forward(self, x):
        con1   = self.relu(self.conv1(x))
        res1   = self.resBlock1(con1)
        res1   = torch.add(res1, con1)
        con2   = self.conv2(res1)
        res2   = self.resBlock2(con2)
        res2   = torch.add(res2, con2)
        con3   = self.conv3(res2)
        res3   = self.resBlock3(con3)
        res3   = torch.add(res3, con3)
        decon1 = self.deconv1(res3)
        deblur_feature = self.deconv2(decon1)
        deblur_out = self.convout(torch.add(deblur_feature, con1))
        return deblur_feature, deblur_out

class _SRMoudle(nn.Module):
    def __init__(self):
        super(_SRMoudle, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, (7, 7), 1, padding=3)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        layers = []
        for i in range(0, 8):
            layers.append(_ResBlockSR(64, 64))
        self.resBlock = nn.ModuleList(layers)
        # self.resBlock = self._makelayers(64, 64, 8, 1)
        self.conv2 = nn.Conv2d(64, 64, (3, 3), 1, 1)

        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                j = i.kernel_size[0] * i.kernel_size[1] * i.out_channels
                i.weight.data.normal_(0, math.sqrt(2 / j))
                if i.bias is not None:
                    i.bias.data.zero_()

    # def _makelayers(self, inchannel, outchannel, block_num, stride=1):
    #     layers = []
    #     for i in range(0, block_num):
    #         layers.append(_ResBlockSR(inchannel, outchannel))
    #     return nn.Sequential(*layers)

    def forward(self, x):
        con1 = self.relu(self.conv1(x))
        res1 = self.resBlock(con1)
        con2 = self.conv2(res1)
        sr_feature = torch.add(con2, con1)
        return sr_feature

class _GateMoudle(nn.Module):
    def __init__(self):
        super(_GateMoudle, self).__init__()

        self.conv1 = nn.Conv2d(131,  64, (3, 3), 1, 1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(64, 64, (1, 1), 1, padding=0)

        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                j = i.kernel_size[0] * i.kernel_size[1] * i.out_channels
                i.weight.data.normal_(0, math.sqrt(2 / j))
                if i.bias is not None:
                    i.bias.data.zero_()

    def forward(self, x):
        con1 = self.relu(self.conv1(x))
        scoremap = self.conv2(con1)
        return scoremap

class _ReconstructMoudle(nn.Module):
    def __init__(self):
        super(_ReconstructMoudle, self).__init__()
        layers = []
        for i in range(0, 8):
            layers.append(_ResBLockDB(64, 64))
        self.resBlock =  nn.ModuleList(layers)
        # self.resBlock = self._makelayers(64, 64, 8)
        self.conv1 = nn.Conv2d(64, 256, (3, 3), 1, 1)
        self.pixelShuffle1 = nn.PixelShuffle(2)
        self.relu1 = nn.LeakyReLU(0.1, inplace=True)
        self.conv2 = nn.Conv2d(64, 256, (3, 3), 1, 1)
        self.pixelShuffle2 = nn.PixelShuffle(2)
        self.relu2 = nn.LeakyReLU(0.2, inplace=True)
        self.conv3 = nn.Conv2d(64, 64, (3, 3), 1, 1)
        self.relu3 = nn.LeakyReLU(0.2, inplace=True)
        self.conv4 = nn.Conv2d(64, 3, (3, 3), 1, 1)

        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                j = i.kernel_size[0] * i.kernel_size[1] * i.out_channels
                i.weight.data.normal_(0, math.sqrt(2 / j))
                if i.bias is not None:
                    i.bias.data.zero_()

    # def _makelayers(self, inchannel, outchannel, block_num, stride=1):
    #     layers = []
    #     for i in range(0, block_num):
    #         layers.append(_ResBLockDB(inchannel, outchannel))
    #     return nn.Sequential(*layers)

    def forward(self, x):
        res1 = self.resBlock(x)
        con1 = self.conv1(res1)
        pixelshuffle1 = self.relu1(self.pixelShuffle1(con1))
        con2 = self.conv2(pixelshuffle1)
        pixelshuffle2 = self.relu2(self.pixelShuffle2(con2))
        con3 = self.relu3(self.conv3(pixelshuffle2))
        sr_deblur = self.conv4(con3)
        return sr_deblur

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.deblurMoudle = _DeblurringMoudle()
        self.srMoudle = _SRMoudle()
        self.geteMoudle = _GateMoudle()
        self.reconstructMoudle = _ReconstructMoudle()
        # self.deblurMoudle      = self._make_net(_DeblurringMoudle)
        # self.srMoudle          = self._make_net(_SRMoudle)
        # self.geteMoudle        = self._make_net(_GateMoudle)
        # self.reconstructMoudle = self._make_net(_ReconstructMoudle)

    def forward(self, x, gated, isTest):
        if isTest == True:
            origin_size = x.size()
            input_size  = (math.ceil(origin_size[2]/4)*4, math.ceil(origin_size[3]/4)*4)
            out_size    = (origin_size[2]*4, origin_size[3]*4)
            x           = nn.functional.upsample(x, size=input_size, mode='bilinear')

        deblur_feature, deblur_out = self.deblurMoudle(x)
        sr_feature = self.srMoudle(x)
        if gated == True:
            scoremap = self.geteMoudle(torch.cat((deblur_feature, x, sr_feature), 1))
        else:
            scoremap = torch.cuda.FloatTensor().resize_(sr_feature.shape).zero_()+1
        repair_feature = torch.mul(scoremap, deblur_feature)
        fusion_feature = torch.add(sr_feature, repair_feature)
        recon_out = self.reconstructMoudle(fusion_feature)

        if isTest == True:
            recon_out = nn.functional.upsample(recon_out, size=out_size, mode='bilinear')

        return deblur_out, recon_out

It seems that I did not use any view(-1, xxx) function as in https://stackoverflow.com/questions/56719867/pytorch-expected-input-batch-size-12-to-match-target-batch-size-64 and in https://discuss.pytorch.org/t/valueerror-expected-input-batch-size-324-to-match-target-batch-size-4/24498/3
I am really confused. Can anyone help? Thanks in advance.

Hi can you try adding some debugging logs to check your tensor shapes.

Here is my training procedure:

def train(train_gen, model, criterion, optimizer, epoch, contrast_w=0):
    epoch_loss = 0
    for iteration, batch in enumerate(train_gen, 1):
        torch.cuda.empty_cache()
        #input, targetdeblur, targetsr
        LR_Blur = batch[0]
        LR_Deblur = batch[1]
        HR = batch[2]
        LR_Blur = LR_Blur.to(device)
        LR_Deblur = LR_Deblur.to(device)
        HR = HR.to(device)

        if opt.isTest == True:
            test_Tensor = torch.cuda.FloatTensor().resize_(1).zero_()+1
        else:
            test_Tensor = torch.cuda.FloatTensor().resize_(1).zero_()
        if opt.gated == True:
            gated_Tensor = torch.cuda.FloatTensor().resize_(1).zero_()+1
        else:
            gated_Tensor = torch.cuda.FloatTensor().resize_(1).zero_()

        print("Before", iteration, LR_Blur.shape, HR.shape, LR_Deblur.shape)
        [lr_deblur, sr] = model(LR_Blur, gated_Tensor, test_Tensor)
        print("after", lr_deblur.shape, sr.shape, LR_Blur.shape, gated_Tensor.shape, test_Tensor.shape)

        loss1 = criterion(lr_deblur, LR_Deblur)
        loss2 = criterion(sr, HR)
        mse = loss2 + opt.lambda_db * loss1

I print the shape of the data before and after entering the model.
Also, I print the shape at the beginning and the end in the forward function of Net.

def forward(self, x, gated, isTest):
        print("IN model", x.shape, gated.shape)
        
        ...... ...... ...... 

        print("[return]", deblur_out.shape, recon_out.shape)
        return deblur_out, recon_out

Here are the outputs:

Before 1 torch.Size([16, 3, 24, 24]) torch.Size([16, 3, 96, 96]) torch.Size([16, 3, 24, 24])
IN model torch.Size([8, 3, 24, 24]) torch.Size([1])
[return] torch.Size([8, 3, 24, 24]) torch.Size([8, 3, 96, 96])
after torch.Size([8, 3, 24, 24]) torch.Size([8, 3, 96, 96]) torch.Size([16, 3, 24, 24]) torch.Size([1]) torch.Size([1])

Here is the complete error message:

Traceback (most recent call last):
  File "train_GFN_4x.py", line 191, in <module>
    train(trainloader, model, criterion, optimizer, epoch, opt.contrast_w)
  File "train_GFN_4x.py", line 118, in train
    loss1 = criterion(lr_deblur, LR_Deblur)
  File "/usr/local/conda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/conda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 445, in forward
    return F.mse_loss(input, target, reduction=self.reduction)
  File "/usr/local/conda/envs/py36/lib/python3.6/site-packages/torch/nn/functional.py", line 2647, in mse_loss
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
  File "/usr/local/conda/envs/py36/lib/python3.6/site-packages/torch/functional.py", line 65, in broadcast_tensors
    return _VF.broadcast_tensors(tensors)
RuntimeError: The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0

do you have the same error without data parallel wrapping?

also, where is the data parallel wrapping codes?

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if opt.resume:
    if os.path.isfile(opt.resume):
        print("Loading from checkpoint {}".format(opt.resume))
        model = torch.load(opt.resume)
        model.load_state_dict(model.state_dict())
        opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(opt.resume)
else:
    model = Net()
    mkdir_steptraing(opt.model_name)

model = model.to(device)
model = nn.DataParallel(model, device_ids=range(opt.n_GPUs))

No. When running my code without DataParallel, I do not have this error. My code can execute correctly.

Also, if I pass a different batch size and run on a single GPU, it is still correct. So I think I did not implicitly specified a specific batch size in my code.

I try to find whether this error is cause by my model.py, so I changed my model to a very simple CNN network.
Here is my code:

import torch
import torch.nn as nn

class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, 2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(64, 3, 1, padding=1)

    def forward(self, x, not_used1, not_used2):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        return x, out
model = DemoModel()
model = model.to(device)
model = nn.DataParallel(model, device_ids=range(opt.n_GPUs))

However, it will also cause error using two GPUs.

RuntimeError: The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0

Will this error be related to h5 files? Because I read my data from a .h5 file.

I modified my dataset and read data from folders but it still failed. I think there is nothing wrong with my device because when using DataParallel in other projects, it worked well.

I finally found the problem. It is due to this line: test_Tensor = torch.cuda.FloatTensor().resize_(1).zero_()+1