How to use torch.exp() with flatten()

Hi everyone,
I’ve trained a CNN model like this:

model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding_mode='zeros'),
    nn.Flatten(),
    nn.Linear(3701776, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(64, 5),
    nn.LogSoftmax(dim=1),
)
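
For reference, the 3701776 input features of the first nn.Linear come from the conv output: with kernel_size=5 and the default padding=0, each 485x485 input shrinks to 481x481, and 16 * 481 * 481 = 3701776 flattened values. A quick sketch to check this with dummy data (assuming the 485x485 RGB inputs used during training):

    import torch
    import torch.nn as nn

    # Conv + flatten only, to inspect the flattened feature size
    features = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5),
        nn.Flatten(),
    )
    dummy = torch.randn(1, 3, 485, 485)  # fake batch of one 485x485 RGB image
    print(features(dummy).shape)         # torch.Size([1, 3701776])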

I can train the model just fine, calculating the loss/accuracy like this:

    with torch.no_grad():
      # Set the model to evaluation mode (disables dropout)
      model.eval()

      # Validation pass
      for images, labels in test_dataloader:
        log_ps = model(images)
        test_loss += criterion(log_ps, labels)

        # Convert log-probabilities back to probabilities
        ps = torch.exp(log_ps)
        top_p, top_class = ps.topk(1, dim=1)
        equals = top_class == labels.view(*top_class.shape)
        accuracy += torch.mean(equals.type(torch.FloatTensor))

But I don’t know how to test individual samples, since when I try to call

ps = torch.exp(model(img))

I get an error:

mat1 and mat2 shapes cannot be multiplied (16x231361 and 3701776x256)

This is probably because I’m not passing in batches, but I’m not sure how to fix it. Any help would be greatly appreciated!

The torch.exp call should not fail, as it’s an elementwise operation and thus won’t raise a shape mismatch. Check the input shapes to your model and make sure they are valid.
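
For example, printing the shape right before the forward pass should show the problem (img here stands for whatever tensor you are passing in):

    print(img.shape)        # this model expects [batch_size, 3, H, W]
    log_ps = model(img)     # the matmul error comes from this line, not from exp
    ps = torch.exp(log_ps)  # elementwise, works on any shape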

Thank you for your answer.

During training the input shape is torch.Size([64, 3, 485, 485]) (64 being the batch size), and torch.exp(model(img)) works fine.

But when I test the model afterwards on a single sample, the input shape is torch.Size([3, 485, 485]), which causes the error. I’m not sure how to fix this, since I don’t want to test a whole batch.

Add the missing batch dimension explicitly via model(img.unsqueeze(0)) and it should work.
The reason for the error is that nn.Flatten keeps dim0 of the intermediate activation and treats it as the batch dimension. When no batch dimension is passed to the model, dim0 is the channel dimension instead, so the conv output is flattened to [16, 231361] (16 channels with 481*481 values each) rather than [1, 3701776], which raises the shape mismatch.
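
A minimal sketch of the fix, using torch.randn as a stand-in for your actual image tensor:

    img = torch.randn(3, 485, 485)  # single image, no batch dimension
    # model(img) would be flattened to [16, 231361] and fail as described above

    model.eval()
    with torch.no_grad():
        # unsqueeze(0) makes the input [1, 3, 485, 485]
        ps = torch.exp(model(img.unsqueeze(0)))  # probabilities, shape [1, 5]
    top_p, top_class = ps.topk(1, dim=1)         # predicted class for the sample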

That worked, thanks a lot!