RuntimeError: The size of tensor a (64) must match the size of tensor b (25) at non-singleton dimension 0

Hi,

I got this error, couldn’t reproduce this. Even Though, I checked all my model dimensions.

Can anyone tell me where the issue exactly?

Below is my code:
class SCNN(nn.Module):
def init(self, class_num=NUM_CHARS):
super(SCNN, self).init()
self.classes = class_num
# in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[0]
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
# in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[1]
self.avg_poo11 = nn.AvgPool2d(kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
self.avg_poo12 = nn.AvgPool2d(kernel_size=1, stride=1, padding=0)

    self.fc1 = nn.Linear(6272, cfg_fc[0])
    self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])

def forward(self, input, time_window=20):
    c1_mem = c1_spike = torch.zeros(args.batch_size, 32, 28, 28, device=DEVICE)
    c2_mem = c2_spike = torch.zeros(args.batch_size, 32, 14, 14, device=DEVICE)

    h1_mem = h1_spike = h1_sumspike = torch.zeros(args.batch_size, cfg_fc[0], device=DEVICE)
    h2_mem = h2_spike = h2_sumspike = torch.zeros(args.batch_size, NUM_CHARS, device=DEVICE)

    for step in range(time_window):  # simulation time steps
        x = input > torch.rand(input.size(), device=DEVICE)  # prob. firing
        # print(x.shape)

        c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
        # print(c1_mem.shape)

        # x = nn.AvgPool2d(kernel_size=1, stride=1, padding=0)

        x = self.avg_poo11(c1_spike)
        # print(x.shape)

        c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
        # print(c2_mem.shape)

        x = self.avg_poo12(c2_spike)
        # print(x.shape)
        x = x.view(args.batch_size, -1)
        # print(x.shape)

        h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
        # print(h1_mem.shape)
        h1_sumspike += h1_spike
        h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
        # print(h2_mem.shape)
        h2_sumspike += h2_spike

    outputs = h2_sumspike / time_window
    return outputs

names = ‘spiking_model’

best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
acc_record = list([])
loss_train_record = list([])
loss_test_record = list([])

snn = SCNN(class_num=NUM_CHARS)
snn.to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=args.learning_rate)

num_epochs = 50

for epoch in range(num_epochs):
running_loss = 0
start_time = time.time()
for i, (images, labels) in enumerate(dataloaders[“train”]):
snn.zero_grad()
optimizer.zero_grad()

    images = images.float().to(DEVICE)
    outputs = snn(images)
    labels_ = torch.zeros(args.batch_size, NUM_CHARS).scatter_(1, labels.view(-1, 1), 1)
    loss = criterion(outputs.cpu(), labels_)
    running_loss += loss.item()
    loss.backward()
    optimizer.step()
    if (i + 1) % 100 == 0:
        print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f'
              % (epoch + 1, num_epochs, i + 1, len(datasets['train']) // args.batch_size, running_loss))
        running_loss = 0
        print('Time elasped:', time.time() - start_time)
correct = 0
total = 0
optimizer = lr_scheduler(optimizer, epoch, args.learning_rate, 40)

with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(dataloaders['val']):
        inputs = inputs.to(DEVICE)
        optimizer.zero_grad()
        outputs = snn(inputs)
        labels_ = torch.zeros(args.batch_size, NUM_CHARS).scatter_(1, targets.view(-1, 1), 1)
        loss = criterion(outputs.cpu(), labels_)
        _, predicted = outputs.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(dataloaders['val']), ' Acc: %.5f' % acc)

print('Iters:', epoch, '\n\n\n')
print('Test Accuracy of the model on the 10000 test images: %.3f' % (100 * correct / total))
acc = 100. * float(correct) / float(total)
acc_record.append(acc)
if epoch % 5 == 0:
    print(acc)
    print('Saving..')
    state = {
        'net': snn.state_dict(),
        'acc': acc,
        'epoch': epoch,
        'acc_record': acc_record,
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt' + names + '.t7')
    best_acc = acc

Could you post the complete stack trace, please?
Also, you can post code snippets by wrapping them into three backticks ```, which would make debugging easier :wink:

Thank you for your response.

I have solved my issue :grinning: