Getting predictions only for the first channel in UNet architecture

Hi, I am trying to do multiclass segmentation with a UNet. My model classes look as follows:

class ConvBlock(nn.Module):
    """Two back-to-back ``Conv2d -> BatchNorm2d -> ReLU`` stages.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. ``**kwargs`` is forwarded unchanged to both
    ``nn.Conv2d`` layers (e.g. ``kernel_size``, ``padding``, ``bias``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        layers = []
        # Build the stages in order so layer indices (and state_dict keys)
        # stay conv_block.0 .. conv_block.5.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, **kwargs),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, X):
        """Apply both conv stages to ``X``."""
        return self.conv_block(X)




class EncoderBlock(nn.Module):
    """One UNet down-sampling stage.

    A ``ConvBlock`` maps ``in_channels`` to ``2 * in_channels``. A 2x2 max-pool
    with stride 2 then downsamples the result. The pre-pool feature map is
    returned too, so the decoder can use it as a skip connection.
    ``**kwargs`` is forwarded to the ``ConvBlock`` convolutions.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        doubled = 2 * in_channels
        self.conv_block = ConvBlock(in_channels, doubled, **kwargs)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        """Return ``(pooled, features)``; ``features`` is the pre-pool map."""
        features = self.conv_block(X)
        return self.pool(features), features




class DecoderBlock(nn.Module):
    """One UNet up-sampling stage.

    ``up_conv`` (ConvTranspose2d, kernel 2, stride 2) halves the channel count
    and doubles the spatial size. The result is concatenated with the skip map
    along dim 1 (the skip is expected to carry ``in_channels // 2`` channels, so
    the concatenation has ``in_channels``). A ``ConvBlock`` then reduces it to
    ``in_channels // 2``. ``**kwargs`` is forwarded to the ``ConvBlock``
    convolutions.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels, in_channels // 2, **kwargs)

    def forward(self, X, skip):
        """Upsample ``X``, concatenate it after ``skip`` on dim 1, and convolve."""
        out = self.up_conv(X)
        # MaxPool2d floors odd sizes on the way down, so the 2x upsample can come
        # back one pixel short of the skip map and torch.cat would fail. Pad the
        # upsampled map (split evenly between both sides) to the skip's H/W;
        # this is a no-op when the sizes already match.
        diff_h = skip.shape[-2] - out.shape[-2]
        diff_w = skip.shape[-1] - out.shape[-1]
        if diff_h or diff_w:
            out = nn.functional.pad(
                out,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        out = torch.cat((skip, out), dim=1)
        return self.conv_block(out)




class UNet(nn.Module):
    """UNet producing per-pixel class logits.

    Args:
        in_channels: Number of channels in the input image.
        encoder_in_channel_list: Input width of each ``EncoderBlock`` (each block
            doubles it), e.g. ``[64, 128, 256]``. The first entry is also the
            width produced by the stem convolutions.
        decoder_in_channel_list: Input width of each ``DecoderBlock`` (each block
            halves it), e.g. ``[1024, 512, 256, 128]``. The first entry is the
            bottleneck output width. One skip map is stored by the stem and one
            by each encoder, and each decoder consumes one, so this list should
            have ``len(encoder_in_channel_list) + 1`` entries.
        out_channels: Number of output channels (one logit per class).
        **kwargs: Forwarded to every ``nn.Conv2d`` in the stem and conv blocks
            (e.g. ``kernel_size``, ``padding``, ``bias``).

    ``forward`` returns raw logits with the class dimension at dim 1, so a class
    map is ``logits.argmax(dim=1)`` (or ``argmax(dim=0)`` after removing the
    batch dim).
    """

    def __init__(self, in_channels: int, encoder_in_channel_list: list[int], decoder_in_channel_list: list[int], out_channels: int, **kwargs):
        super().__init__()

        # Stem: two conv/BN/ReLU stages mapping the `in_channels` input channels
        # to encoder_in_channel_list[0].
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels, encoder_in_channel_list[0], **kwargs),
            nn.BatchNorm2d(encoder_in_channel_list[0]),
            nn.ReLU(),
            nn.Conv2d(encoder_in_channel_list[0], encoder_in_channel_list[0], **kwargs),
            nn.BatchNorm2d(encoder_in_channel_list[0]),
            nn.ReLU()
        )
        self.pool_after_in_conv = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.encoder = nn.ModuleList()
        for n_channels in encoder_in_channel_list:  # e.g. 64, 128, 256
            self.encoder.append(EncoderBlock(n_channels, **kwargs))

        # Bottleneck: the last encoder emits twice its input width.
        self.bottleneck = ConvBlock(
            encoder_in_channel_list[-1] * 2,
            decoder_in_channel_list[0],
            **kwargs
        )

        # Decoder
        self.decoder = nn.ModuleList()
        for n_channels in decoder_in_channel_list:  # e.g. 1024, 512, 256, 128
            self.decoder.append(DecoderBlock(n_channels, **kwargs))

        # 1x1 projection to one logit channel per class (the last decoder emits
        # half of its input width).
        self.out_conv = nn.Conv2d(decoder_in_channel_list[-1] // 2, out_channels, kernel_size=1)

    def forward(self, X):
        """Return class logits for ``X`` (classes along dim 1)."""
        skips = []
        out = self.in_conv(X)
        skips.append(out)  # stem features, before the first pooling
        out = self.pool_after_in_conv(out)
        for encoder_block in self.encoder:
            out, s = encoder_block(out)
            skips.append(s)
        out = self.bottleneck(out)
        # Decoders consume the skips deepest-first (LIFO order).
        for decoder_block in self.decoder:
            out = decoder_block(out, skips.pop())
        return self.out_conv(out)


# Width entering each encoder stage; every EncoderBlock doubles it.
encoder_in_channel_list = [64, 128, 256]
# Bottleneck output width followed by each decoder's input width; every
# DecoderBlock halves it.
decoder_in_channel_list = [1024, 512, 256, 128]

model = UNet(
    in_channels=3,
    encoder_in_channel_list=encoder_in_channel_list,
    decoder_in_channel_list=decoder_in_channel_list,
    out_channels=24,  # one logit channel per class (DiceScore uses num_classes=24)
    kernel_size=3,
    padding='same',
    bias=False,
).to(device)

The training loop for the model is:

def train(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_fn,
        accuracy_fn,
        device: torch.device,
        dataloader: torch.utils.data.DataLoader
) -> tuple[float, float]:
    """Run one training epoch and return the mean per-batch loss and metric.

    Args:
        model: Network returning class logits with classes on dim 1 ([N, C, H, W]).
        optimizer: Optimizer stepped once per batch.
        loss_fn: Called as ``loss_fn(logits, target)`` with an integer class map
            ``target`` of shape [N, H, W] (e.g. ``nn.CrossEntropyLoss()``).
        accuracy_fn: Called as ``accuracy_fn(pred_one_hot, target_one_hot)``;
            both are one-hot [N, C, H, W] tensors, with the prediction taken as
            the per-pixel argmax of the logits.
        device: Device every batch is moved to.
        dataloader: Yields ``(X, y)``; ``y`` holds class indices, shaped
            [N, 1, H, W] or [N, H, W].

    Returns:
        ``(mean_loss, mean_metric)`` as Python floats averaged over batches.
    """
    agg_loss, agg_accuracy = 0.0, 0.0
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # CrossEntropyLoss needs [N, H, W]: drop only the channel dim. A bare
        # squeeze() would also drop the batch dim when N == 1.
        if y.dim() == 4:
            y = y.squeeze(1)
        # NOTE(review): if the masks come from ToTensor(), their values are scaled
        # into [0, 1] and .long() truncates them to 0, so the model learns to
        # predict class 0 everywhere. Confirm the mask transform keeps integer ids.
        y = y.long()

        optimizer.zero_grad()
        pred = model(X)  # [N, C, H, W] logits
        loss = loss_fn(pred, y)

        # The metric is configured for one-hot inputs, so convert the logits to a
        # hard per-pixel prediction rather than passing raw logits.
        num_classes = pred.shape[1]
        pred_one_hot = torch.nn.functional.one_hot(pred.argmax(dim=1), num_classes).permute(0, 3, 1, 2)
        y_one_hot = torch.nn.functional.one_hot(y, num_classes).permute(0, 3, 1, 2)  # [N, H, W, C] --> [N, C, H, W]
        accuracy = accuracy_fn(pred_one_hot, y_one_hot)

        loss.backward()
        optimizer.step()

        # detach() so the running sum does not keep every batch's autograd graph
        # (and its activations) alive until the end of the epoch.
        agg_loss += loss.detach()
        agg_accuracy += accuracy

    num_batches = len(dataloader)
    return float(agg_loss / num_batches), float(agg_accuracy / num_batches)


def validate(
        model: torch.nn.Module,
        loss_fn,
        accuracy_fn,
        device: torch.device,
        dataloader: torch.utils.data.DataLoader
) -> tuple[float, float]:
    """Evaluate ``model`` without updating it; return mean per-batch loss and metric.

    Args:
        model: Network returning class logits with classes on dim 1 ([N, C, H, W]).
        loss_fn: Called as ``loss_fn(logits, target)`` with an integer class map
            ``target`` of shape [N, H, W] (e.g. ``nn.CrossEntropyLoss()``).
        accuracy_fn: Called as ``accuracy_fn(pred_one_hot, target_one_hot)``;
            both are one-hot [N, C, H, W] tensors, with the prediction taken as
            the per-pixel argmax of the logits.
        device: Device every batch is moved to.
        dataloader: Yields ``(X, y)``; ``y`` holds class indices, shaped
            [N, 1, H, W] or [N, H, W].

    Returns:
        ``(mean_loss, mean_metric)`` as Python floats averaged over batches.
    """
    agg_loss, agg_accuracy = 0.0, 0.0
    model.eval()
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            # Drop only the channel dim; a bare squeeze() would also drop N == 1.
            if y.dim() == 4:
                y = y.squeeze(1)
            # NOTE(review): masks scaled into [0, 1] (e.g. by ToTensor()) all
            # become 0 after .long(); confirm masks keep integer class ids.
            y = y.long()

            pred = model(X)  # [N, C, H, W] logits
            loss = loss_fn(pred, y)

            # One-hot metric input: use the hard argmax prediction, not logits.
            num_classes = pred.shape[1]
            pred_one_hot = torch.nn.functional.one_hot(pred.argmax(dim=1), num_classes).permute(0, 3, 1, 2)
            y_one_hot = torch.nn.functional.one_hot(y, num_classes).permute(0, 3, 1, 2)  # [N, H, W, C] --> [N, C, H, W]
            accuracy = accuracy_fn(pred_one_hot, y_one_hot)

            agg_loss += loss
            agg_accuracy += accuracy
    num_batches = len(dataloader)
    return float(agg_loss / num_batches), float(agg_accuracy / num_batches)


def run(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_fn,
        accuracy_fn,
        device: torch.device,
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        epochs: int,
        early_stop: EarlyStop,
        verbose_after_every_n_epoch: int = 5
):
    """Alternate one training and one validation pass per epoch.

    After each epoch, ``early_stop`` is called with the validation loss and the
    model; training ends early once ``early_stop.stop`` is truthy. Progress is
    printed every ``verbose_after_every_n_epoch`` epochs (a value <= 0 disables
    printing).

    Returns:
        ``(t_loss, v_loss, t_accuracy, v_accuracy)``: per-epoch histories.
    """
    t_loss, t_accuracy = [], []
    v_loss, v_accuracy = [], []
    for epoch in tqdm(range(1, epochs + 1)):
        train_loss, train_accuracy = train(
            model,
            optimizer,
            loss_fn,
            accuracy_fn,
            device,
            train_dataloader
        )
        t_loss.append(train_loss)
        t_accuracy.append(train_accuracy)

        val_loss, val_accuracy = validate(
            model,
            loss_fn,
            accuracy_fn,
            device,
            val_dataloader
        )
        v_loss.append(val_loss)
        v_accuracy.append(val_accuracy)

        # Guard the modulo so a non-positive interval disables logging instead
        # of raising ZeroDivisionError.
        if verbose_after_every_n_epoch > 0 and epoch % verbose_after_every_n_epoch == 0:
            print(f"Epoch {epoch}\n------------")
            print(f"Train Loss: {t_loss[-1]:.2f}\tValidation Loss: {v_loss[-1]:.2f}")
            print(f"Train Accuracy: {t_accuracy[-1]:.2f}\tValidation Accuracy: {v_accuracy[-1]:.2f}\n\n")
        # Early stopping is driven by the validation loss.
        early_stop(val_loss, model)
        if early_stop.stop:
            break
    return t_loss, v_loss, t_accuracy, v_accuracy

I am using nn.CrossEntropyLoss() as my loss_fn, and accuracy_fn is torchmetrics.segmentation.DiceScore(num_classes=24, include_background=True, input_format='one-hot', average='macro').

The model's prediction contains logits for all 24 channels, but when I apply argmax(dim=0) I get only zeros, which means the logit of the first channel is the highest everywhere. I do not understand what the issue is. Please help!

Is this a double-post from here or why did you not continue the discussion in the previous thread?

Assuming your output contains a batch dimension in dim0, argmax(dim=0) would be wrong since you are reducing the batch dimension instead of the channel dimension.

I was trying to make a prediction for just a single image:

# The model was built with 3 input channels, so force RGB (grayscale/RGBA/CMYK
# files would otherwise give a different channel count).
img = Image.open('test.jpg').convert('RGB')
# NOTE(review): `transforms` should match the validation-time image transforms.
img = transforms(img)  # [C, H, W]
# Add the batch dim and move to the same device the model was moved to.
img = img.unsqueeze(dim=0).to(device)  # [N, C, H, W]
model.eval()
with torch.inference_mode():
    pred = model(img)  # [N, C, H, W] logits, classes on dim 1
    pred = pred.squeeze(dim=0)  # [N, C, H, W] --> [C, H, W]
    pred = pred.argmax(dim=0)  # [C, H, W] --> [H, W] class map
# NOTE(review): if pred is still all zeros, check the training masks. Masks that
# went through ToTensor() are scaled into [0, 1] and become 0 after .long(), so
# the model only ever saw class 0 as the target.

This pred contains only zeros.