RuntimeError: The expanded size of the tensor (3) must match the existing size (16) at non-singleton dimension 0. Target sizes: [3, 15, 15]. Tensor sizes: [16, 15, 15]

Why do I get this error? I pass CIFAR-10 images of size [3, 32, 32] to this model:

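import torch
import torch.nn as nn
from IPython.core.debugger import set_trace

def conv_block(in_channels, out_channels, k):
    # Conv -> BatchNorm -> ReLU -> 2x2 max-pool.
    # padding=0, so the conv shrinks height and width by k - 1.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, k, padding=0),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )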

class Top(nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = conv_block(3, 16, 3)
    self.lin = nn.Linear(20, 10)
    self.childone = Second()
    self.childtwo = Second()
  def forward(self, x):
    # set_trace()
    a = self.childone(self.encoder(x))
    b = self.childtwo(self.encoder(x))
    # print('top', a.shape, b.shape)
    out = torch.cat((a, b), dim=-1)
    return self.lin(out) 

class Second(nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = conv_block(16, 32, 3)
    self.lin = nn.Linear(20, 10)
    self.childone = Middle()
    self.childtwo = Middle()

  def forward(self, x):
    a = self.childone(self.encoder(x))
    b = self.childtwo(self.encoder(x))
    # print('second', a.shape, b.shape)
    out = torch.cat((a, b), dim=-1)
    return self.lin(out)

class Middle(nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = conv_block(32, 64, 1)
    self.lin = nn.Linear(20, 10)
    self.childone = Bottom()
    self.childtwo = Bottom()

  def forward(self, x):
    a = self.childone(self.encoder(x))
    b = self.childtwo(self.encoder(x))
    # print('middle', a.shape, b.shape)
    out = torch.cat((a, b), dim=-1)
    return self.lin(out)

class Bottom(nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = conv_block(64, 128, 1)
    self.lin_one = nn.Linear(128, 10)
  def forward(self, x):
    
    # print('bottom', x.shape)
    out = self.encoder(x)
    return (self.lin_one(out.view(out.size(0), -1)))

model = Top()
# inp = [None, train_dataset[0][0]]
model.to('cuda')

Hi,

I guess x does not have the right shape at the start of Second's forward?
Can you add prints there to check that?
Also, do you have the exact stack trace showing where this error comes from?

The error is raised in the training loop:

for i, (data, target) in enumerate(train_loader):
    data, target = data.to('cuda'), target.to('cuda')
    optimizer.zero_grad()
    set_trace()
    output = model(data)

It is raised by the model(data) call.
Stepping through with set_trace() in the forward of Top() gives:

ipdb> n
> <ipython-input-4-c93b1c3acece>(82)forward()
     80   def forward(self, x):
     81     set_trace()
---> 82     a = self.childone(self.encoder(x))
     83     b = self.childtwo(self.encoder(x))
     84     # print('top', a.shape, b.shape)

ipdb> x.shape
torch.Size([100, 3, 32, 32])
ipdb> n
RuntimeError: The expanded size of the tensor (3) must match the existing size (16) at non-singleton dimension 0.  Target sizes: [3, 15, 15].  Tensor sizes: [16, 15, 15]
> <ipython-input-4-c93b1c3acece>(82)forward()
     80   def forward(self, x):
     81     set_trace()
---> 82     a = self.childone(self.encoder(x))
     83     b = self.childtwo(self.encoder(x))
     84     # print('top', a.shape, b.shape)

When I use

torchsummary.summary(model, (3, 32, 32), batch_size=100)

the model summary prints without any error, but I still get the error during training.

But where is this error raised inside the execution of this line of Top()? Is it during Second? Or Middle? Or Bottom?

I put set_trace() in the forward of every class. None of the set_trace() calls in the other classes is reached; only the one in Top() is hit, and then the error is raised:

RuntimeError: The expanded size of the tensor (3) must match the existing size (16) at non-singleton dimension 0.  Target sizes: [3, 15, 15].  Tensor sizes: [16, 15, 15]

When the error is raised, you should have a full traceback showing where it occurred, no?
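If the debugger only shows the last frame, one way to get the full traceback is to catch the exception and print it with the standard traceback module. This is only a sketch: forward and failing_hook below are hypothetical stand-ins for model(data) and whatever code actually raises, not code from this thread.

```python
import traceback


def failing_hook(activation_shape):
    # Hypothetical stand-in for the code that really raises the error.
    raise RuntimeError(f"cannot handle activation of shape {activation_shape}")


def forward(x_shape):
    # Hypothetical stand-in for model(data); the failure happens deeper inside.
    return failing_hook((16, 15, 15))


def run_step(x_shape):
    """Run one step and return (output, traceback_text)."""
    try:
        return forward(x_shape), None
    except RuntimeError:
        # format_exc() returns the whole call chain as text, innermost frame last.
        return None, traceback.format_exc()


output, tb_text = run_step((100, 3, 32, 32))
print(tb_text)
```

The last frames of the printed text name the function that raised, which is usually enough to tell whether the error comes from the model itself or from something attached to it.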

This is weird, as at the top level nothing has size 15x15.
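For what it is worth, the shapes can be checked by hand. The sketch below applies the usual output-size formulas for an unpadded, stride-1 convolution followed by MaxPool2d(2) to the four conv_block calls in the posted model. The helper names conv_out and pool_out are made up for this example. By this arithmetic, [16, 15, 15] is the per-sample output of the encoder in Top().

```python
def conv_out(size, kernel, padding=0, stride=1):
    # Output length of a convolution along one spatial dimension.
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    # Output length of max-pooling with stride equal to the kernel size.
    return (size - kernel) // kernel + 1


# (class name, out_channels, conv kernel size) as in the posted conv_block calls
blocks = [("Top", 16, 3), ("Second", 32, 3), ("Middle", 64, 1), ("Bottom", 128, 1)]

size = 32  # CIFAR-10 images are 32x32
shapes = {}
for name, channels, k in blocks:
    size = pool_out(conv_out(size, k))
    shapes[name] = (channels, size, size)

for name, shape in shapes.items():
    print(name, shape)
```

Bottom ends at 128 x 1 x 1, which flattens to the 128 features that nn.Linear(128, 10) expects, so the layer sizes in the model are consistent with each other.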

OK, I got it. I had registered a hook that calls torchvision's make_grid, and that is what raised the error. Thanks.
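This would also explain why none of the set_trace() calls in the child classes was reached. A forward hook runs as part of the module call, so a hook on the encoder fails before self.childone is ever invoked. The sketch below imitates that order in plain Python. TinyModule is a hypothetical stand-in written for this example, not PyTorch's implementation, and bad_hook merely raises instead of calling make_grid.

```python
class TinyModule:
    """Hypothetical stand-in for a module: runs forward, then its hooks."""

    def __init__(self, name, forward_fn):
        self.name = name
        self.forward_fn = forward_fn
        self.hooks = []

    def register_forward_hook(self, hook):
        self.hooks.append(hook)

    def __call__(self, x):
        out = self.forward_fn(x)
        for hook in self.hooks:
            # Hooks run here, before the caller receives `out`.
            hook(self, (x,), out)
        return out


calls = []  # records which forward functions actually ran


def encoder_forward(x):
    calls.append("encoder")
    return x


def child_forward(x):
    calls.append("child")
    return x


def bad_hook(module, inputs, output):
    # Stand-in for a visualization hook that cannot handle the activation.
    raise RuntimeError("hook failed on output of " + module.name)


encoder = TinyModule("encoder", encoder_forward)
child = TinyModule("child", child_forward)
encoder.register_forward_hook(bad_hook)

error_message = None
try:
    # Same pattern as `self.childone(self.encoder(x))` in the posted forward.
    child(encoder("batch"))
except RuntimeError as exc:
    error_message = str(exc)

print(calls, error_message)
```

The encoder's forward runs, the hook raises, and the child is never called, which matches the debugger stopping on the first line of the forward in Top().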