Hi, I am trying to do multiclass segmentation using a UNet. My model classes look as follows:
class ConvBlock(nn.Module):
    """Two consecutive (Conv2d -> BatchNorm2d -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Any extra keyword arguments (e.g. ``kernel_size``,
    ``padding``, ``bias``) are forwarded unchanged to both ``nn.Conv2d`` layers.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        layers = []
        # Build both stages in order so the Sequential indices (0..5) match the
        # classic hand-written layout: conv, bn, relu, conv, bn, relu.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, **kwargs),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, X):
        """Apply both conv stages to ``X``."""
        return self.conv_block(X)
class EncoderBlock(nn.Module):
    """UNet encoder stage: a ConvBlock that doubles the channels, then 2x2 max-pool.

    ``forward`` returns ``(pooled, features)``: the down-sampled map feeds the
    next stage and the full-resolution ``features`` serve as the skip tensor.
    Extra keyword arguments are forwarded to the inner ``ConvBlock``.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        doubled = 2 * in_channels
        self.conv_block = ConvBlock(in_channels, doubled, **kwargs)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        """Return the pooled output and the pre-pool skip features."""
        features = self.conv_block(X)
        return self.pool(features), features
class DecoderBlock(nn.Module):
    """UNet decoder stage: 2x transposed-conv upsampling, skip concat, ConvBlock.

    ``in_channels`` is the channel count of the incoming (deeper) feature map.
    The upsampling halves it, and ``skip`` is expected to carry the remaining
    ``in_channels // 2`` channels, so the concatenation has ``in_channels``
    channels; the ConvBlock then reduces that to ``in_channels // 2``.
    Extra keyword arguments are forwarded to the inner ``ConvBlock``.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels, in_channels // 2, **kwargs)

    def forward(self, X, skip):
        """Upsample ``X``, concatenate it after ``skip`` on dim 1, and convolve.

        If the encoder max-pooled an odd-sized map (MaxPool2d floors), the
        upsampled tensor comes out 1px smaller than ``skip``; it is zero-padded
        to ``skip``'s spatial size so that ``torch.cat`` does not fail.
        """
        out = self.up_conv(X)
        diff_h = skip.shape[-2] - out.shape[-2]
        diff_w = skip.shape[-1] - out.shape[-1]
        if diff_h or diff_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            out = nn.functional.pad(
                out,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        out = torch.cat((skip, out), dim=1)
        return self.conv_block(out)
class UNet(nn.Module):
    """UNet for (multiclass) semantic segmentation.

    Args:
        in_channels: Number of channels of the input image.
        encoder_in_channel_list: Input channel count of each ``EncoderBlock``
            (each block doubles its channels). The first entry is also the
            width of the initial convolution stem.
        decoder_in_channel_list: Input channel count of each ``DecoderBlock``
            (each block halves its channels). The first entry is the
            bottleneck's output width.
        out_channels: Number of output maps, i.e. the number of classes
            (including background for multiclass segmentation).
        **kwargs: Forwarded to every ``nn.Conv2d`` in the stem, encoder,
            bottleneck and decoder (not to the final 1x1 conv).

    ``forward`` returns raw, unnormalized logits of shape
    ``(N, out_channels, H_out, W_out)``. The per-pixel class map of a batch is
    ``logits.argmax(dim=1)``. ``argmax(dim=0)`` reduces over the *batch*
    dimension instead; it is only correct on a single un-batched
    ``(C, H, W)`` sample such as ``logits[i]``.
    """

    def __init__(self, in_channels: int, encoder_in_channel_list: list[int], decoder_in_channel_list: list[int], out_channels: int, **kwargs):
        super().__init__()
        # Stem: two conv stages lifting the input image to
        # encoder_in_channel_list[0] channels; its output is the shallowest skip.
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels, encoder_in_channel_list[0], **kwargs),
            nn.BatchNorm2d(encoder_in_channel_list[0]),
            nn.ReLU(),
            nn.Conv2d(encoder_in_channel_list[0], encoder_in_channel_list[0], **kwargs),
            nn.BatchNorm2d(encoder_in_channel_list[0]),
            nn.ReLU()
        )
        self.pool_after_in_conv = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: each block doubles its input channels and max-pools by 2.
        self.encoder = nn.ModuleList()
        for n_channels in encoder_in_channel_list:  # e.g. 64, 128, 256
            self.encoder.append(EncoderBlock(n_channels, **kwargs))
        # Bottleneck: deepest encoder output (last entry * 2) -> first decoder width.
        self.bottleneck = ConvBlock(
            encoder_in_channel_list[-1] * 2,
            decoder_in_channel_list[0],
            **kwargs
        )
        # Decoder: each block upsamples by 2 and halves its input channels.
        self.decoder = nn.ModuleList()
        for n_channels in decoder_in_channel_list:  # e.g. 1024, 512, 256, 128
            self.decoder.append(DecoderBlock(n_channels, **kwargs))
        # 1x1 conv producing one logit map per output channel (class).
        self.out_conv = nn.Conv2d(decoder_in_channel_list[-1] // 2, out_channels, kernel_size=1)

    def forward(self, X):
        """Return per-class logits for ``X``; take ``argmax(dim=1)`` for labels."""
        skip = []
        out = self.in_conv(X)
        skip.append(out)
        out = self.pool_after_in_conv(out)
        for encoder_block in self.encoder:
            out, s = encoder_block(out)
            skip.append(s)
        out = self.bottleneck(out)
        # Skips are consumed deepest-first (LIFO), mirroring the encoder path.
        for decoder_block in self.decoder:
            s = skip.pop()
            out = decoder_block(out, s)
        return self.out_conv(out)
encoder_in_channel_list = [64, 128, 256]
decoder_in_channel_list = [1024, 512, 256, 128]
# 3-channel input image -> 24 output logit maps (one per class).
model = UNet(
    in_channels=3,
    encoder_in_channel_list=encoder_in_channel_list,
    decoder_in_channel_list=decoder_in_channel_list,
    out_channels=24,
    kernel_size=3,
    padding='same',
    bias=False,
).to(device)
The training loop for the model is:
def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn,
    accuracy_fn,
    device: torch.device,
    dataloader: torch.utils.data.DataLoader
):
    """Run one training epoch and return the mean loss and mean metric.

    Args:
        model: Segmentation network returning logits of shape (N, C, H, W).
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Called as ``loss_fn(logits, target)`` with integer class-index
            targets of shape (N, H, W), e.g. ``nn.CrossEntropyLoss()``.
        accuracy_fn: Called as ``accuracy_fn(pred_one_hot, target_one_hot)``,
            both integer one-hot tensors of shape (N, C, H, W), e.g. a
            ``DiceScore(..., input_format='one-hot')``.
        device: Device each batch is moved to.
        dataloader: Sized iterable of ``(X, y)`` batches; ``y`` may be
            (N, 1, H, W) or (N, H, W).

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats averaged over batches.
    """
    agg_loss, agg_accuracy = 0.0, 0.0
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device).long()
        # Drop only the channel dim, [N, 1, H, W] -> [N, H, W], as required by
        # CrossEntropyLoss. A bare squeeze() would also drop the batch dim
        # whenever a batch holds a single sample.
        if y.dim() == 4:
            y = y.squeeze(1)
        optimizer.zero_grad()
        pred = model(X)  # [N, C, H, W] logits
        loss = loss_fn(pred, y)
        # .item() detaches: accumulating the loss tensor itself kept every
        # batch's autograd graph alive for the whole epoch.
        agg_loss += loss.item()
        with torch.no_grad():
            num_classes = pred.shape[1]
            # The metric expects one-hot *predictions*, not raw logits: take the
            # per-pixel argmax over the class dim (dim=1), then one-hot encode.
            # [N, H, W, C] --> [N, C, H, W]
            pred_one_hot = one_hot(pred.argmax(dim=1), num_classes=num_classes).permute(0, 3, 1, 2)
            target_one_hot = one_hot(y, num_classes=num_classes).permute(0, 3, 1, 2)
            agg_accuracy += float(accuracy_fn(pred_one_hot, target_one_hot))
        loss.backward()
        del pred, loss
        optimizer.step()
    num_batches = len(dataloader)
    return agg_loss / num_batches, agg_accuracy / num_batches
def validate(
    model: torch.nn.Module,
    loss_fn,
    accuracy_fn,
    device: torch.device,
    dataloader: torch.utils.data.DataLoader
):
    """Evaluate ``model`` on ``dataloader`` and return the mean loss and metric.

    Args:
        model: Segmentation network returning logits of shape (N, C, H, W).
        loss_fn: Called as ``loss_fn(logits, target)`` with integer class-index
            targets of shape (N, H, W), e.g. ``nn.CrossEntropyLoss()``.
        accuracy_fn: Called as ``accuracy_fn(pred_one_hot, target_one_hot)``,
            both integer one-hot tensors of shape (N, C, H, W).
        device: Device each batch is moved to.
        dataloader: Sized iterable of ``(X, y)`` batches; ``y`` may be
            (N, 1, H, W) or (N, H, W).

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats averaged over batches.
    """
    agg_loss, agg_accuracy = 0.0, 0.0
    model.eval()
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device).long()
            # Drop only the channel dim, [N, 1, H, W] -> [N, H, W]; a bare
            # squeeze() would also drop the batch dim for single-sample batches.
            if y.dim() == 4:
                y = y.squeeze(1)
            pred = model(X)  # [N, C, H, W] logits
            loss = loss_fn(pred, y)
            num_classes = pred.shape[1]
            # One-hot encode the per-pixel argmax (class dim = 1) rather than
            # passing raw logits to a metric that expects one-hot predictions.
            # [N, H, W, C] --> [N, C, H, W]
            pred_one_hot = one_hot(pred.argmax(dim=1), num_classes=num_classes).permute(0, 3, 1, 2)
            target_one_hot = one_hot(y, num_classes=num_classes).permute(0, 3, 1, 2)
            agg_loss += loss.item()
            agg_accuracy += float(accuracy_fn(pred_one_hot, target_one_hot))
            del pred, loss
    num_batches = len(dataloader)
    return agg_loss / num_batches, agg_accuracy / num_batches
def run(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn,
    accuracy_fn,
    device: torch.device,
    train_dataloader: torch.utils.data.DataLoader,
    val_dataloader: torch.utils.data.DataLoader,
    epochs: int,
    early_stop: EarlyStop,
    verbose_after_every_n_epoch: int = 5
):
    """Train for up to ``epochs`` epochs, validating after each one.

    Args:
        model, optimizer, loss_fn, accuracy_fn, device: Passed through to
            ``train`` / ``validate``.
        train_dataloader: Batches used for the training epoch.
        val_dataloader: Batches used for validation.
        epochs: Maximum number of epochs.
        early_stop: Called as ``early_stop(val_loss, model)`` after every
            epoch; training stops once its ``stop`` attribute is truthy.
        verbose_after_every_n_epoch: Print metrics every this many epochs;
            values <= 0 disable printing.

    Returns:
        ``(t_loss, v_loss, t_accuracy, v_accuracy)``: per-epoch lists.
    """
    t_loss, t_accuracy = [], []
    v_loss, v_accuracy = [], []
    for epoch in tqdm(range(1, epochs + 1)):
        loss, accuracy = train(
            model,
            optimizer,
            loss_fn,
            accuracy_fn,
            device,
            train_dataloader
        )
        t_loss.append(loss)
        t_accuracy.append(accuracy)
        loss, accuracy = validate(
            model,
            loss_fn,
            accuracy_fn,
            device,
            val_dataloader
        )
        v_loss.append(loss)
        v_accuracy.append(accuracy)
        # Guard the modulo: a non-positive interval would raise ZeroDivisionError.
        if verbose_after_every_n_epoch > 0 and epoch % verbose_after_every_n_epoch == 0:
            print(f"Epoch {epoch}\n------------")
            print(f"Train Loss: {t_loss[-1]:.2f}\tValidation Loss: {v_loss[-1]:.2f}")
            print(f"Train Accuracy: {t_accuracy[-1]:.2f}\tValidation Accuracy: {v_accuracy[-1]:.2f}\n\n")
        # Early stopping monitors the validation loss of this epoch.
        early_stop(v_loss[-1], model)
        if early_stop.stop:
            break
    return t_loss, v_loss, t_accuracy, v_accuracy
I am using `nn.CrossEntropyLoss()` as my `loss_fn`, and my `accuracy_fn` is `torchmetrics.segmentation.DiceScore(num_classes=24, include_background=True, input_format='one-hot', average='macro')`.
The model's prediction contains logits for all 24 channels, but when I apply `argmax(dim=0)` I get only zeros, which means the logit of the first channel is always the highest. I do not understand what the issue is here. Please help!