RuntimeError: Expected 4-dimensional input for 4-dimensional weight

I have a network, in which there are 3 architectures that share the same classifier.

class VGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels,batch_norm=False):

        super(VGGBlock,self).__init__()

        conv2_params = {'kernel_size': (3, 3),
                        'stride'     : (1, 1),
                        'padding'   : 1
                        }

        noop = lambda x : x

        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels,out_channels=out_channels , **conv2_params)
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else noop

        self.conv2 = nn.Conv2d(in_channels=out_channels,out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else noop

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        return self._batch_norm

    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)

        return x
class VGG16(nn.Module):

  def __init__(self, input_size, num_classes=1,batch_norm=False):
    super(VGG16, self).__init__()

    self.in_channels,self.in_width,self.in_height = input_size

    self.block_1 = VGGBlock(self.in_channels,64,batch_norm=batch_norm)
    self.block_2 = VGGBlock(64, 128,batch_norm=batch_norm)
    self.block_3 = VGGBlock(128, 256,batch_norm=batch_norm)
    self.block_4 = VGGBlock(256,512,batch_norm=batch_norm)

  @property
  def input_size(self):
      return self.in_channels,self.in_width,self.in_height

  def forward(self, x):

    x = self.block_1(x)
    x = self.block_2(x)
    x = self.block_3(x)
    x = self.block_4(x)

    return x
class VGG16Classifier(nn.Module):

  def __init__(self, num_classes=1,classifier = None,batch_norm=False):
    super(VGG16Classifier, self).__init__()


    self._vgg_a = VGG16((1,32,32),batch_norm=True)
    self._vgg_b = VGG16((1,32,32),batch_norm=True)
    self._vgg_star = VGG16((1,32,32),batch_norm=True)
    self.classifier = classifier

    if (self.classifier is None):
        self.classifier = nn.Sequential(
          nn.Linear(2048, 2048),
          nn.ReLU(True),
          nn.Dropout(p=0.5),
          nn.Linear(2048, 512),
          nn.ReLU(True),
          nn.Dropout(p=0.5),
          nn.Linear(512, num_classes)
        )

  def forward(self, x1,x2,x3):
      op1 = self._vgg_a(x1)
      op1 = torch.flatten(op1,1)
      op2 = self._vgg_b(x2)
      op2 = torch.flatten(op2,1)
      op3 = self._vgg_star(x3) 
      op3 = torch.flatten(op3,1)
      
      x1 = self.classifier(op1)
      x2 = self.classifier(op2)
      x3 = self.classifier(op3)

      return x1,x2,x3
model1 = VGG16((1,32,32),batch_norm=True)
model2 = VGG16((1,32,32),batch_norm=True)
model_star = VGG16((1,32,32),batch_norm=True)
model_combo = VGG16Classifier(model1,model2,model_star)

I want to traing model_combo using the following loss function:

class CombinedLoss(nn.Module):
    def __init__(self, loss_a, loss_b, loss_star, _lambda=1.0):
        super().__init__()
        self.loss_a = loss_a
        self.loss_b = loss_b
        self.loss_star = loss_star

        self.register_buffer('_lambda',torch.tensor(float(_lambda),dtype=torch.float32))


    def forward(self,y_hat,y):

        return (self.loss_a(y_hat[0],y[0]) + 
                self.loss_b(y_hat[1],y[1]) + 
                self.loss_combo(y_hat[2],y[2]) + 
                self._lambda * torch.sum(model_star.weight - torch.pow(torch.cdist(model1.weight+model2.weight), 2)))

In the training function I pass loaders, that for simplicity are loaders_a, loaders_b and again loaders_a, where loaders_a is related to the first 50% of data of MNIST and loaders_b to the latter 50% of MNIST.

def train(net, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
      loaders_a, loaders_b, loaders_star = loaders
    # try:
      net = net.to(dev)
      #print(net)
      #summary(net,[(net.in_channels,net.in_width,net.in_height)]*2)


      criterion.to(dev)


      # Initialize history
      history_loss = {"train": [], "val": [], "test": []}
      history_accuracy_a = {"train": [], "val": [], "test": []}
      history_accuracy_b = {"train": [], "val": [], "test": []}
      history_accuracy_star = {"train": [], "val": [], "test": []}
      # Store the best val accuracy
      best_val_accuracy = 0

      # Process each epoch
      for epoch in range(epochs):
        # Initialize epoch variables
        sum_loss = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_a = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_b = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_star = {"train": 0, "val": 0, "test": 0}

        progbar = None
        # Process each split
        for split in ["train", "val", "test"]:
          if split == "train":
            net.train()
            #widgets = [
              #' [', pb.Timer(), '] ',
              #pb.Bar(),
              #' [', pb.ETA(), '] ', pb.Variable('ta','[Train Acc: {formatted_value}]')]

            #progbar = pb.ProgressBar(max_value=len(loaders_a[split]),widgets=widgets,redirect_stdout=True)

          else:
            net.eval()
          # Process each batch
          for j, ((input_a, labels_a), (input_b, labels_b), (input_s, labels_s)) in enumerate(zip(loaders_a[split], loaders_b[split], loaders_star[split])):
            labels_a = labels_a.unsqueeze(1).float()
            labels_b = labels_b.unsqueeze(1).float()
            labels_s = labels_s.unsqueeze(1).float()

            input_a = input_a.to(dev)
            labels_a = labels_a.to(dev)
            input_b = input_b.to(dev)
            labels_b = labels_b.to(dev)
            input_s = input_s.to(dev)
            labels_s = labels_s.to(dev)

            # Reset gradients
            optimizer.zero_grad()
            # Compute output
            pred = net(input_a,input_b, input_s)

            loss = criterion(pred, [labels_a, labels_b, labels_s])
            # Update loss
            sum_loss[split] += loss.item()
            # Check parameter update
            if split == "train":
              # Compute gradients
              loss.backward()
              # Optimize
              optimizer.step()

            # Compute accuracy
            pred_labels = (pred[2] >= 0.0).long()  # Binarize predictions to 0 and 1
            pred_labels_a = (pred[0] >= 0.0).long()  # Binarize predictions to 0 and 1
            pred_labels_b = (pred[1] >= 0.0).long()  # Binarize predictions to 0 and 1


            batch_accuracy_star = (pred_labels == labels_s).sum().item() / len(labels_s)
            batch_accuracy_a = (pred_labels_a == labels_a).sum().item() / len(labels_a)
            batch_accuracy_b = (pred_labels_b == labels_b).sum().item() / len(labels_b)
            # Update accuracy
            sum_accuracy_star[split] += batch_accuracy_star
            sum_accuracy_a[split] += batch_accuracy_a
            sum_accuracy_b[split] += batch_accuracy_b

            #if (split=='train'):
              #progbar.update(j, ta=batch_accuracy)
              #progbar.update(j, ta=batch_accuracy_a)
              #progbar.update(j, ta=batch_accuracy_b)

        #if (progbar is not None):
          #progbar.finish()
        # Compute epoch loss/accuracy
        #for split in ["train", "val", "test"]:
          #epoch_loss = sum_loss[split] / (len(loaders_a[split])+len(loaders_b[split])) 
          #epoch_accuracy_combo = {split: sum_accuracy_combo[split] / len(loaders[split]) for split in ["train", "val", "test"]}
          #epoch_accuracy_a = sum_accuracy_a[split] / len(loaders_a[split])
          #epoch_accuracy_b = sum_accuracy_b[split] / len(loaders_b[split])
        epoch_loss = sum_loss["train"] / (len(loaders_a["train"])+len(loaders_b["train"])+len(loaders_s["train"])) 
        epoch_accuracy_a = sum_accuracy_a["train"] / len(loaders_a["train"])
        epoch_accuracy_b = sum_accuracy_b["train"] / len(loaders_b["train"])
        epoch_accuracy_star = sum_accuracy_star["train"] / len(loaders_s["train"]) 

        epoch_loss_val = sum_loss["val"] / (len(loaders_a["val"])+len(loaders_b["val"])+len(loaders_s["val"])) 
        epoch_accuracy_a_val = sum_accuracy_a["val"] / len(loaders_a["val"])
        epoch_accuracy_b_val = sum_accuracy_b["val"] / len(loaders_b["val"])
        epoch_accuracy_star_val = sum_accuracy_star["val"] / len(loaders_s["val"]) 

        epoch_loss_test = sum_loss["test"] / (len(loaders_a["test"])+len(loaders_b["test"])+len(loaders_s["test"])) 
        epoch_accuracy_a_test = sum_accuracy_a["test"] / len(loaders_a["test"])
        epoch_accuracy_b_test = sum_accuracy_b["test"] / len(loaders_b["test"])
        epoch_accuracy_star_test = sum_accuracy_star["test"] / len(loaders_s["test"]) 


        # Store params at the best validation accuracy
        if save_param and epoch_accuracy["val"] > best_val_accuracy:
          # torch.save(net.state_dict(), f"{net.__class__.__name__}_best_val.pth")
          torch.save(net.state_dict(), f"{model_name}_best_val.pth")
          best_val_accuracy = epoch_accuracy["val"]

        # Update history
        for split in ["train", "val", "test"]:
          history_loss[split].append(epoch_loss)
          history_accuracy_a[split].append(epoch_accuracy_a)
          history_accuracy_b[split].append(epoch_accuracy_b)
          history_accuracy_star[split].append(epoch_accuracy_star)
        # Print info
        print(f"Epoch {epoch + 1}:",
              f"Training Loss = {epoch_loss:.4f},",)
        print(f"Epoch {epoch + 1}:",
              f"Training Accuracy for A = {epoch_accuracy_a:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Training Accuracy for B = {epoch_accuracy_b:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Training Accuracy for star = {epoch_accuracy_star:.4f},")
        
        print(f"Epoch {epoch + 1}:",
              f"Val Loss = {epoch_loss_val:.4f},",)
        print(f"Epoch {epoch + 1}:",
              f"Val Accuracy for A = {epoch_accuracy_a_val:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Val Accuracy for B = {epoch_accuracy_b_val:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Val Accuracy for star = {epoch_accuracy_star_val:.4f},")
        
        print(f"Epoch {epoch + 1}:",
              f"Test Loss = {epoch_loss_test:.4f},",)
        print(f"Epoch {epoch + 1}:",
              f"Test Accuracy for A = {epoch_accuracy_a_test:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Test Accuracy for B = {epoch_accuracy_b_test:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Test Accuracy for star = {epoch_accuracy_star_test:.4f},")
        print("\n")

But I got this error:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 3, 3], but got 2-dimensional input of size [128, 2048] instead

Hi Bruno!

You are passing a tensor of the wrong shape into the first Conv2d
layer in one of your VGGBlocks, most likely because you are passing
an input batch of 1-d vectors, rather than “3-d” images to your model.

Conv2d requires a shape of [nBatch, in_channels, height, width]
for its input (where nBatch can be arbitrary, in_channels matches
the in_channels of the Conv2d, and height and width are at least
as large kernel_size). You need the in_channels dimension, even
if in_channels = 1.

I’m guessing that you have a batch size of 128, and that your input
images have been flattened into 1-d vectors of length 2048.

The following example illustrates a valid input and then reproduces
your error with invalid input:

>>> import torch
>>> torch.__version__
'1.7.1'
>>> conv = torch.nn.Conv2d (1, 64, (3, 3))
>>> x_good = torch.randn (128, 1, 32, 64)
>>> x_bad = torch.randn (128, 32 * 64)
>>> conv (x_good).shape
torch.Size([128, 64, 30, 62])
>>> conv (x_bad).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 419, in _conv_forward
    return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 3, 3], but got 2-dimensional input of size [128, 2048] instead

Best.

K. Frank

I edited the original post: I have flatten in the forward, just before passing to the classifier. I think it is correct, but it does not work…