Hi! I’m trying to build a U-Net model, and I want it to predict breast tumors in mammography, so I only have two classes. For a small dataset and a few epochs, it just finds the mask of the whole breast out of the background, and not the tumors, and for larger training dataset and more epochs, after a point it just stop predicting as I print the prediction in each epoch, and everything turns 0 or 1. How am I suppose to find what is wrong?
# the U-Net is adapted from one of @ptrblck posts!
import torch
from torch import nn
import torch.nn.functional as F
import os
import random
import pydicom
from PIL import Image
import numpy as np
import torch.utils.data as utils_data
from os import listdir
from os.path import isfile, join
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.utils.data as data
import matplotlib.pyplot as plt
import cv2
def img_to_numpy(img_path):
"""Reads general image path and returns np.array normalized to 1
"""
mode_to_bits = {'1':1, 'L':8, 'P':8, 'RGB':8, 'RGBA':32, 'CMYK':32, 'YCbCr':24, 'I':16, 'F':32}
img = Image.open(open(img_path, 'rb'))
bits = mode_to_bits[img.mode]
img = np.array(img, dtype=float)
norm = 2**bits-1
img /= norm
return img
def dicom_to_numpy(img_path):
"""Reads DICOM image path and returns np.array normalized to 1
"""
img = pydicom.dcmread(img_path)
bits = img.BitsAllocated
img = img.pixel_array
img = img.astype(float)
norm = 2**bits-1
img /= norm
return img
class BaseConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding,
stride):
super(BaseConv, self).__init__()
self.act = nn.ReLU()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding,
stride)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
padding, stride)
def forward(self, x):
x = self.act(self.conv1(x))
x = self.act(self.conv2(x))
return x
class DownConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding,
stride):
super(DownConv, self).__init__()
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv_block = BaseConv(in_channels, out_channels, kernel_size,
padding, stride)
def forward(self, x):
x = self.pool1(x)
x = self.conv_block(x)
return x
class UpConv(nn.Module):
def __init__(self, in_channels, in_channels_skip, out_channels,
kernel_size, padding, stride):
super(UpConv, self).__init__()
self.conv_trans1 = nn.ConvTranspose2d(
in_channels, in_channels, kernel_size=2, padding=0, stride=2)
self.conv_block = BaseConv(
in_channels=in_channels + in_channels_skip,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride)
def forward(self, x, x_skip):
x = self.conv_trans1(x)
x = torch.cat((x, x_skip), dim=1)
x = self.conv_block(x)
return x
class UNet(nn.Module):
def __init__(self, in_channels, out_channels, n_class, kernel_size,
padding, stride):
super(UNet, self).__init__()
self.init_conv = BaseConv(in_channels, out_channels, kernel_size,
padding, stride)
self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,
padding, stride)
self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,
padding, stride)
self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,
padding, stride)
self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,
kernel_size, padding, stride)
self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,
kernel_size, padding, stride)
self.up1 = UpConv(2 * out_channels, out_channels, out_channels,
kernel_size, padding, stride)
self.out = nn.Conv2d(out_channels, n_class, kernel_size, padding, stride)
def forward(self, x):
# Encoder
x = self.init_conv(x)
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
# Decoder
x_up = self.up3(x3, x2)
x_up = self.up2(x_up, x1)
x_up = self.up1(x_up, x)
x_out = F.sigmoid(self.out(x_up))
return x_out
class MyDataSet(utils_data.Dataset):
def __init__(self, root_dir, image_dir, mask_dir, label, img_transform=None, mask_transform=None):
self.dataset_path = root_dir
self.image_dir = image_dir
self.mask_dir = mask_dir
self.img_transform = img_transform
self.mask_transform = mask_transform
mask_full_path = os.path.join(self.dataset_path, self.mask_dir)
self.mask_file_list = [f for f in listdir(mask_dir) if isfile(join(mask_dir, f))]
random.shuffle(self.mask_file_list)
self.mapping = {
0.007782101167315175: 0,
0.5019455252918288: 0,
0.9961089494163424: 1
}
def mask_to_class(self, mask):
for k in self.mapping:
mask[mask == k] = self.mapping[k]
return mask
def blockshaped(arr, nrows, ncols):
h, w = arr.shape
return (arr.reshape(h//nrows, nrows, -1, ncols).swapaxes(1,2).reshape(-1, nrows, ncols))
def __getitem__(self, index):
file_name = self.mask_file_list[index]
img_name = os.path.join(self.dataset_path, self.image_dir, file_name.replace(".png", ".dcm"))
mask_name = os.path.join(self.dataset_path, self.mask_dir, file_name)
image = dicom_to_numpy(img_name)
mask = img_to_numpy(mask_name)
if self.img_transform:
image = cv2.resize(image, (256, 256))
if self.mask_transform:
mask = cv2.resize(mask, (256, 256))
mask = self.mask_to_class(mask)
unique, counts = np.unique(mask, return_counts=True)
dict(zip(unique, counts))
image = blockshaped(image, 256, 256)
mask = blockshaped(mask, 256, 256)
labels = []
img = []
for j in range(len(mask)):
labels.append(torch.from_numpy(mask[j]))
for j in range(len(image)):
img.append(torch.from_numpy(image[j]))
# for patches, instead of one single image and mask, an array of images, and an array of masks will be passed.
return img, labels
def __len__(self):
return len(self.mask_file_list)
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
def blockshaped(arr, nrows, ncols):
h, w = arr.shape
return (arr.reshape(h//nrows, nrows, -1, ncols).swapaxes(1,2).reshape(-1, nrows, ncols))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
transform_pipeline = transforms.Compose([
transforms.ToPILImage(mode=None),
transforms.Scale((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
transform_pipeline_mask = transforms.Compose([
transforms.ToPILImage(mode=None),
transforms.Scale((1024, 1024)),
transforms.ToTensor(),
])
image_dir = 'INbreast/test_mass/mass'
mask_dir = 'INbreast/test_mass/mask'
label = 'mass'
traindir = '/gdrive/My Drive/'
testdir = '/gdrive/My Drive/'
batch_size = 1
workers = 2
train_data = MyDataSet(traindir, image_dir, mask_dir, label, img_transform=transform_pipeline,
mask_transform=transform_pipeline_mask)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=workers,
pin_memory=True)
test_data = MyDataSet(testdir, image_dir, mask_dir, label, img_transform=None,
mask_transform=None)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=workers,
pin_memory=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=1,
out_channels=64,
n_class=2,
kernel_size=3,
padding=1,
stride=1).to(device)
optimizer = torch.optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
dataloader = train_loader
listt = []
epochs = 10
def mask_to_class(mask):
mapping = {
510: 0,
65280: 1
}
for k in mapping:
mask[mask == k] = mapping[k]
# print("1")
return mask
for epochs in range(5):
print(epochs)
running_loss = 0.0
i = 0
for a, b in train_loader: # for X, y in dataloader:
j = 0
for j in range(len(a)):
X = a[j].float()
y = b[j].float()
X = X.to(device) # [N, 1, H, W]
y = y.to(device, dtype=torch.int64) # [N, H, W] with class indices (0, 1)
X = X.unsqueeze(0)
prediction = model(X) # [N, 2, H, W]
print("y:", y.shape)
print("predictiong:", prediction.shape)
loss = F.cross_entropy(prediction, y.long())
optimizer.zero_grad()
loss.backward()
optimizer.step()
prediction = prediction.squeeze(0)
unique, counts = np.unique(y.byte().cpu().numpy(), return_counts=True)
dict(zip(unique, counts))
print("target:")
print(dict(zip(unique, counts)))
unique, counts = np.unique(prediction[0].byte().cpu().numpy(), return_counts=True)
dict(zip(unique, counts))
print("prediction:")
print(dict(zip(unique, counts)))
unique, counts = np.unique(prediction[1].byte().cpu().numpy(), return_counts=True)
print(dict(zip(unique, counts)))
running_loss += loss.item()
listt.append(running_loss / 2000)
running_loss += loss.item()
scheduler.step(running_loss)
running_loss = 0.0
scheduler.step(running_loss)
plt.plot(listt)
plt.show()
print('Finished Training')