Binary classifier issue (not using semantic segmentation)

I am getting the error “RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [8]” when trying to train my binary classifier. I have noticed other posts that address the same error, but I think my case is different: the data I am passing through the system is an image followed by a 0 or a 1, which represents the class of the photo. I was able to make this work in TensorFlow but can’t seem to get past this error message in PyTorch. Any help would be greatly appreciated.

import torch
import torch.nn as nn
from tqdm import tqdm


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # note: this Sequential is defined here but never used in forward()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 5, 1, padding=0),
            nn.Conv2d(5, 32, 1, padding = 0),
            nn.Conv2d(32, 2, 1, padding = 0),
            nn.ReLU(inplace=True)
        )
        
        self.block1 = self.conv_block(c_in=3, c_out=256, dropout=0.1, kernel_size=5, stride=1, padding=2)
        self.block2 = self.conv_block(c_in=256, c_out=128, dropout=0.1, kernel_size=3, stride=1, padding=1)
        self.block3 = self.conv_block(c_in=128, c_out=64, dropout=0.1, kernel_size=3, stride=1, padding=1)
        # kernel_size=56 with padding=0 only produces a 1x1 output if the incoming feature map is exactly 56x56
        self.lastcnn = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=56, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
            
    def forward(self, x):
        x = self.block1(x)
        x = self.maxpool(x)        
        x = self.block2(x)        
        x = self.block3(x)
        x = self.maxpool(x)        
        x = self.lastcnn(x)      

        return x
    
    def conv_block(self, c_in, c_out, dropout,  **kwargs):
        seq_block = nn.Sequential(
            nn.Conv2d(in_channels=c_in, out_channels=c_out, **kwargs),
            nn.BatchNorm2d(num_features=c_out),
            nn.ReLU(),
            nn.Dropout2d(p=dropout)
            #nn.Flatten()
        )        
        return seq_block

print('Begin training')

for e in tqdm(range(1, 9)):
    train_epoch_loss = 0
    train_epoch_acc = 0
    net.train()
    for X_train_batch, y_train_batch in train_loader:
        X_train_batch, y_train_batch = X_train_batch.to(device), y_train_batch.to(device)
        optimizer.zero_grad()
        
        y_train_pred = net(X_train_batch).squeeze()
        train_loss = criterion(y_train_pred, y_train_batch)
        train_acc = binary_acc(y_train_pred, y_train_batch)
        
        train_loss.backward()
        optimizer.step()
        
        train_epoch_loss += train_loss.item()
        train_epoch_acc += train_acc.item()
        
    with torch.no_grad():
        net.eval()

I know that I’m getting this error message because the criterion is expecting an image-like (spatial) target at that line, but how would I edit this so that it just takes a 0 or a 1 representing the class?

@ptrblck I have seen you comment on some similar issues; would you mind giving my code a look to see if there are any problems you can point out? I’ve tried changing the loss function to BCEWithLogitsLoss(), adding squeeze() to the output, and adding nn.Flatten() to my layers, but I’m just not getting anywhere with it. My image batch shape is torch.Size([8, 3, 256, 256]) and my label batch shape is torch.Size([8]).
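
For reference, this is roughly how I’m checking those shapes (just a sketch; train_loader, net, and device are defined as above):

X_batch, y_batch = next(iter(train_loader))
print(X_batch.shape)                    # torch.Size([8, 3, 256, 256])
print(y_batch.shape)                    # torch.Size([8])
print(net(X_batch.to(device)).shape)    # raw model output that reaches the criterion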

The error is raised because you are trying to use a scalar target (a single class index for each sample) with an image-like (spatial) model output, as you’ve already explained.
This won’t work out of the box, and I don’t know how TensorFlow manages this use case without a hidden reduction somewhere.

If you are working on multi-class classification, which seems to be the case based on your target, you would have to provide class logits as the model output in the shape [batch_size, nb_classes], where each value in dim1 represents the logit for the corresponding class index.
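
For reference, assuming you are using nn.CrossEntropyLoss (the criterion this error message usually comes from), here is a minimal, standalone example of the expected shapes with made-up values:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(8, 2)            # [batch_size, nb_classes] model output
targets = torch.randint(0, 2, (8,))   # [batch_size] class indices (0 or 1), dtype long
loss = criterion(logits, targets)     # works: no spatial dimensions involved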

Usually this is done by flattening the conv layer output and passing it through a final classifier, e.g. a linear layer (see the sketch at the end of this post).
If you don’t want to use a linear layer, you would have to make sure the last activation output from lastcnn is a single-pixel output in the shape [batch_size, nb_classes, 1, 1].
If your output already has this shape, you could use squeeze to provide the desired output via:

output = output.squeeze(3).squeeze(2)

and pass it to the criterion.
If this activation has more than a single pixel, you could either change the model architecture (by increasing the stride, etc.) or add an adaptive pooling layer, which would output the 1x1 activation.
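
To make both options concrete, here is a rough, untested sketch; the features stack and the layer sizes below are placeholders, not your exact architecture:

import torch
import torch.nn as nn

# stand-in for your block1/block2/block3 + maxpool stack (placeholder sizes)
features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

# option 1: conv classifier + adaptive pooling to guarantee a 1x1 activation
conv_head = nn.Sequential(
    nn.Conv2d(64, 2, kernel_size=1),       # [batch_size, nb_classes, H, W]
    nn.AdaptiveAvgPool2d(output_size=1),   # [batch_size, nb_classes, 1, 1]
)

# option 2: flatten + linear classifier; in_features must match the flattened size
linear_head = nn.Sequential(
    nn.AdaptiveAvgPool2d(output_size=8),   # fix the spatial size so in_features is known
    nn.Flatten(),                          # [batch_size, 64 * 8 * 8]
    nn.Linear(64 * 8 * 8, 2),              # [batch_size, nb_classes]
)

x = torch.randn(8, 3, 256, 256)            # same shape as your image batches
feat = features(x)

out1 = conv_head(feat).squeeze(3).squeeze(2)   # [8, 2]
out2 = linear_head(feat)                       # [8, 2]

criterion = nn.CrossEntropyLoss()
target = torch.randint(0, 2, (8,))             # [8] class indices
print(criterion(out1, target), criterion(out2, target))

Both heads end up with a [batch_size, nb_classes] output, which is what the class-index target expects.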

Awesome, thanks for the help, I will try that now.