Hi,
I am trying to setup a fully convolutional network for segmentation of my microscopy images (Dropbox - Error - Simplify your life). For each image I have the binary mask as well as the raw image from the microscope. I have only one class for the images, and I am only interested in pixel segmentation.
I understand that for loading my own dataset I need to create a custom torch.utils.data.dataset
class. So I made an attempt on this. Then I proceeded with making the torch.utils.data.DataLoader
object from this class. I make my model based on GitHub - wkentaro/pytorch-fcn: PyTorch Implementation of Fully Convolutional Networks. (Training code to reproduce the original result is available.). I noticed that when I start training my model, the progress gets stuck at 0%. When I looked into why this is, I realized that for some reason when I try to run a loop (for or enumerate) over my DataLoader
objects (train_loader
, val_loader
), the scripts gets stuck. I wonder if anyone can help me what am I doing wrong here? Thank you very much!
Abbas Jariani
import datetime
import os
import os.path as osp
import fcn
import numpy as np
import torch
import torch.nn as nn
#import torchfcn
import torchvision
##################################
#import datetime
from distutils.version import LooseVersion
import math
#import os
#import os.path as osp
import shutil
import pytz
import scipy.misc
from torch.autograd import Variable
import torch.nn.functional as F
import tqdm
import random
from PIL import Image
import torch.utils.data as utils_data
###################################def get_upsampling_weight(in_channels, out_channels, kernel_size):
“”“Make a 2D bilinear kernel suitable for upsampling”“”
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = np.ogrid[:kernel_size, :kernel_size]
filt = (1 - abs(og[0] - center) / factor) *
(1 - abs(og[1] - center) / factor)
weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
dtype=np.float64)
weight[range(in_channels), range(out_channels), :, :] = filt
return torch.from_numpy(weight).float()class FCN32s(nn.Module):
pretrained_model = \ osp.expanduser('~/data/models/pytorch/fcn32s_from_caffe.pth') @classmethod def download(cls): return fcn.data.cached_download( url='http://drive.google.com/uc?id=0B9P1L--7Wd2vM2oya3k0Zlgtekk', path=cls.pretrained_model, md5='8acf386d722dc3484625964cbe2aba49', ) def __init__(self, n_class=21): super(FCN32s, self).__init__() # conv1 self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) self.relu1_1 = nn.ReLU(inplace=True) self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) self.relu1_2 = nn.ReLU(inplace=True) self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 # conv2 self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) self.relu2_1 = nn.ReLU(inplace=True) self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) self.relu2_2 = nn.ReLU(inplace=True) self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 # conv3 self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) self.relu3_1 = nn.ReLU(inplace=True) self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) self.relu3_2 = nn.ReLU(inplace=True) self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) self.relu3_3 = nn.ReLU(inplace=True) self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 # conv4 self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) self.relu4_1 = nn.ReLU(inplace=True) self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) self.relu4_2 = nn.ReLU(inplace=True) self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) self.relu4_3 = nn.ReLU(inplace=True) self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 # conv5 self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) self.relu5_1 = nn.ReLU(inplace=True) self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) self.relu5_2 = nn.ReLU(inplace=True) self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) self.relu5_3 = nn.ReLU(inplace=True) self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 # fc6 self.fc6 = nn.Conv2d(512, 4096, 7) self.relu6 = nn.ReLU(inplace=True) self.drop6 = nn.Dropout2d() # fc7 self.fc7 = nn.Conv2d(4096, 4096, 1) self.relu7 = nn.ReLU(inplace=True) self.drop7 = nn.Dropout2d() self.score_fr = nn.Conv2d(4096, n_class, 1) self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, bias=False) self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): m.weight.data.zero_() if m.bias is not None: m.bias.data.zero_() if isinstance(m, nn.ConvTranspose2d): assert m.kernel_size[0] == m.kernel_size[1] initial_weight = get_upsampling_weight( m.in_channels, m.out_channels, m.kernel_size[0]) m.weight.data.copy_(initial_weight) def forward(self, x): h = x h = self.relu1_1(self.conv1_1(h)) h = self.relu1_2(self.conv1_2(h)) h = self.pool1(h) h = self.relu2_1(self.conv2_1(h)) h = self.relu2_2(self.conv2_2(h)) h = self.pool2(h) h = self.relu3_1(self.conv3_1(h)) h = self.relu3_2(self.conv3_2(h)) h = self.relu3_3(self.conv3_3(h)) h = self.pool3(h) h = self.relu4_1(self.conv4_1(h)) h = self.relu4_2(self.conv4_2(h)) h = self.relu4_3(self.conv4_3(h)) h = self.pool4(h) h = self.relu5_1(self.conv5_1(h)) h = self.relu5_2(self.conv5_2(h)) h = self.relu5_3(self.conv5_3(h)) h = self.pool5(h) h = self.relu6(self.fc6(h)) h = self.drop6(h) h = self.relu7(self.fc7(h)) h = self.drop7(h) h = self.score_fr(h) h = self.upscore(h) h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous() return h def copy_params_from_vgg16(self, vgg16): features = [ self.conv1_1, self.relu1_1, self.conv1_2, self.relu1_2, self.pool1, self.conv2_1, self.relu2_1, self.conv2_2, self.relu2_2, self.pool2, self.conv3_1, self.relu3_1, self.conv3_2, self.relu3_2, self.conv3_3, self.relu3_3, self.pool3, self.conv4_1, self.relu4_1, self.conv4_2, self.relu4_2, self.conv4_3, self.relu4_3, self.pool4, self.conv5_1, self.relu5_1, self.conv5_2, self.relu5_2, self.conv5_3, self.relu5_3, self.pool5, ] for l1, l2 in zip(vgg16.features, features): if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): assert l1.weight.size() == l2.weight.size() assert l1.bias.size() == l2.bias.size() l2.weight.data = l1.weight.data l2.bias.data = l1.bias.data for i, name in zip([0, 3], ['fc6', 'fc7']): l1 = vgg16.classifier[i] l2 = getattr(self, name) l2.weight.data = l1.weight.data.view(l2.weight.size()) l2.bias.data = l1.bias.data.view(l2.bias.size())
##################################################
def _fast_hist(label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return histdef label_accuracy_score(label_trues, label_preds, n_class):
“”"Returns accuracy score evaluation result.- overall accuracy - mean accuracy - mean IU - fwavacc """ hist = np.zeros((n_class, n_class)) for lt, lp in zip(label_trues, label_preds): hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) acc = np.diag(hist).sum() / hist.sum() with np.errstate(divide='ignore', invalid='ignore'): acc_cls = np.diag(hist) / hist.sum(axis=1) acc_cls = np.nanmean(acc_cls) with np.errstate(divide='ignore', invalid='ignore'): iu = np.diag(hist) / ( hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) ) mean_iu = np.nanmean(iu) freq = hist.sum(axis=1) / hist.sum() fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() return acc, acc_cls, mean_iu, fwavacc
def cross_entropy2d(input, target, weight=None, size_average=True):
# input: (n, c, h, w), target: (n, h, w)
n, c, h, w = input.size()
# log_p: (n, c, h, w)
if LooseVersion(torch.version) < LooseVersion(‘0.3’):
# ==0.2.X
log_p = F.log_softmax(input)
else:
# >=0.3
log_p = F.log_softmax(input, dim=1)
# log_p: (nhw, c)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
log_p = log_p.view(-1, c)
# target: (nhw,)
mask = target >= 0
target = target[mask]
loss = F.nll_loss(log_p, target, weight=weight, reduction=‘sum’)
if size_average:
loss /= mask.data.sum()
return lossclass Trainer(object):
def __init__(self, cuda, model, optimizer, train_loader, val_loader, out, max_iter, size_average=False, interval_validate=None): self.cuda = cuda self.model = model self.optim = optimizer self.train_loader = train_loader self.val_loader = val_loader self.timestamp_start = \ datetime.datetime.now(pytz.timezone('Asia/Tokyo')) self.size_average = size_average if interval_validate is None: self.interval_validate = len(self.train_loader) else: self.interval_validate = interval_validate self.out = out if not osp.exists(self.out): os.makedirs(self.out) self.log_headers = [ 'epoch', 'iteration', 'train/loss', 'train/acc', 'train/acc_cls', 'train/mean_iu', 'train/fwavacc', 'valid/loss', 'valid/acc', 'valid/acc_cls', 'valid/mean_iu', 'valid/fwavacc', 'elapsed_time', ] if not osp.exists(osp.join(self.out, 'log.csv')): with open(osp.join(self.out, 'log.csv'), 'w') as f: f.write(','.join(self.log_headers) + '\n') self.epoch = 0 self.iteration = 0 self.max_iter = max_iter self.best_mean_iu = 0 def validate(self): training = self.model.training self.model.eval() n_class = len(self.val_loader.dataset.class_names) val_loss = 0 visualizations = [] label_trues, label_preds = [], [] for batch_idx, (data, target) in tqdm.tqdm( enumerate(self.val_loader), total=len(self.val_loader), desc='Valid iteration=%d' % self.iteration, ncols=80, leave=False): if self.cuda: data, target = data.cuda(), target.cuda() data, target = Variable(data), Variable(target) with torch.no_grad(): score = self.model(data) loss = cross_entropy2d(score, target, size_average=self.size_average) loss_data = loss.data.item() if np.isnan(loss_data): raise ValueError('loss is nan while validating') val_loss += loss_data / len(data) imgs = data.data.cpu() lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :] lbl_true = target.data.cpu() for img, lt, lp in zip(imgs, lbl_true, lbl_pred): img, lt = self.val_loader.dataset.untransform(img, lt) label_trues.append(lt) label_preds.append(lp) if len(visualizations) < 9: viz = fcn.utils.visualize_segmentation( lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class) visualizations.append(viz) metrics = label_accuracy_score( label_trues, label_preds, n_class) out = osp.join(self.out, 'visualization_viz') if not osp.exists(out): os.makedirs(out) out_file = osp.join(out, 'iter%012d.jpg' % self.iteration) scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations)) val_loss /= len(self.val_loader) with open(osp.join(self.out, 'log.csv'), 'a') as f: elapsed_time = ( datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - self.timestamp_start).total_seconds() log = [self.epoch, self.iteration] + [''] * 5 + \ [val_loss] + list(metrics) + [elapsed_time] log = map(str, log) f.write(','.join(log) + '\n') mean_iu = metrics[2] is_best = mean_iu > self.best_mean_iu if is_best: self.best_mean_iu = mean_iu torch.save({ 'epoch': self.epoch, 'iteration': self.iteration, 'arch': self.model.__class__.__name__, 'optim_state_dict': self.optim.state_dict(), 'model_state_dict': self.model.state_dict(), 'best_mean_iu': self.best_mean_iu, }, osp.join(self.out, 'checkpoint.pth.tar')) if is_best: shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'), osp.join(self.out, 'model_best.pth.tar')) if training: self.model.train() def train_epoch(self): self.model.train() n_class = len(self.train_loader.dataset.class_names) #print("XXXXX") #print(list(enumerate((self.train_loader)) )) #print("XXXXX") for batch_idx, (data, target) in tqdm.tqdm( enumerate(self.train_loader), total=len(self.train_loader), desc='Train epoch=%d' % self.epoch, ncols=80, leave=False): iteration = batch_idx + self.epoch * len(self.train_loader) if self.iteration != 0 and (iteration - 1) != self.iteration: continue # for resuming self.iteration = iteration if self.iteration % self.interval_validate == 0: self.validate() assert self.model.training if self.cuda: data, target = data.cuda(), target.cuda() data, target = Variable(data), Variable(target) ############### # #print(type(data)) #print(type(target)) #print("XXXXXX") ######### self.optim.zero_grad() score = self.model(data) loss = cross_entropy2d(score, target, size_average=self.size_average) loss /= len(data) loss_data = loss.data.item() if np.isnan(loss_data): raise ValueError('loss is nan while training') loss.backward() self.optim.step() metrics = [] lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :] lbl_true = target.data.cpu().numpy() acc, acc_cls, mean_iu, fwavacc = \ label_accuracy_score( lbl_true, lbl_pred, n_class=n_class) metrics.append((acc, acc_cls, mean_iu, fwavacc)) metrics = np.mean(metrics, axis=0) with open(osp.join(self.out, 'log.csv'), 'a') as f: elapsed_time = ( datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - self.timestamp_start).total_seconds() log = [self.epoch, self.iteration] + [loss_data] + \ metrics.tolist() + [''] * 5 + [elapsed_time] log = map(str, log) f.write(','.join(log) + '\n') if self.iteration >= self.max_iter: break def train(self): max_epoch = int(math.ceil(1. * self.max_iter / len(self.train_loader))) for epoch in tqdm.trange(self.epoch, max_epoch, desc='Train', ncols=80): self.epoch = epoch self.train_epoch() if self.iteration >= self.max_iter: break
def get_parameters(model, bias=False):
import torch.nn as nn
modules_skipped = (
nn.ReLU,
nn.MaxPool2d,
nn.Dropout2d,
nn.Sequential,
FCN32s,
#torchfcn.models.FCN16s,
#torchfcn.models.FCN8s,
)
for m in model.modules():
if isinstance(m, nn.Conv2d):
if bias:
yield m.bias
else:
yield m.weight
elif isinstance(m, nn.ConvTranspose2d):
# weight is frozen because it is just a bilinear upsampling
if bias:
assert m.bias is None
elif isinstance(m, modules_skipped):
continue
else:
raise ValueError(‘Unexpected module: %s’ % str(m))class MyDataSet(utils_data.Dataset):
def init(self, root_dir, image_dir, mask_dir, label, transform=None):
self.dataset_path = root_dir
self.image_dir = image_dir
self.mask_dir = os.path.join(mask_dir, label)
self.transform = transform
mask_full_path = os.path.join(self.dataset_path, self.mask_dir)
self.mask_file_list = [f for f in os.listdir(mask_full_path) if osp.isfile(osp.join(mask_full_path, f))]
random.shuffle(self.mask_file_list)
self.class_names = [‘1’]def getitem(self, index):
file_name = self.mask_file_list[index].rsplit(‘.’, 1)[0]
img_name = os.path.join(self.dataset_path, self.image_dir, file_name+‘.jpg’)
mask_name = os.path.join(self.dataset_path, self.mask_dir, self.mask_file_list[index])
image = Image.open(img_name)
mask = Image.open(mask_name)
image = np.array(image)
image = np.rollaxis(image, 2, 0)
image = np.array(image).astype(np.float32)
labels = np.array(mask).astype(np.uint8)
sample = {‘image’: image, ‘labels’: labels}if self.transform: sample = self.transform(sample) return sample
def len(self):
return len(self.mask_file_list)normalize = torchvision.transforms.Normalize(mean=1, std= 0.3)
transform_pipeline = torchvision.transforms.Compose([
#torchvision.transforms.RandomResizedCrop(12),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
#normalize,
])#mask_dir = ‘\\Moana\oc1\OcellO Projects\VitroScan\190430_create_trainset_deepl\OVCA2018_08\curated_masks_from_oline’
#tritc_dir = ‘\\Moana\oc1\OcellO Projects\VitroScan\190430_create_trainset_deepl\OVCA2018_08\curated_tritc_from_oline’train_tritc = ‘\\Moana\oc1\OcellO Projects\VitroScan\190430_create_trainset_deepl\OVCA2018_08\train_tritc\’
train_masks = ‘\\Moana\oc1\OcellO Projects\VitroScan\190430_create_trainset_deepl\OVCA2018_08\train_masks\’
valid_tritc = ‘\\Moana\oc1\OcellO Projects\VitroScan\190430_create_trainset_deepl\OVCA2018_08\valid_tritc\’
valid_masks = ‘\\Moana\oc1\OcellO Projects\VitroScan\190430_create_trainset_deepl\OVCA2018_08\valid_masks\’
#train_data = MyDataSet( args.traindir, args.image_dir, args.mask_dir, args.label)train_data = MyDataSet( ‘D:\myTools\pytorch-fcn-master\examples\voc\out_FCN32s’, train_tritc, train_masks, ‘1’)
valid_data = MyDataSet( ‘D:\myTools\pytorch-fcn-master\examples\voc\out_FCN32s’, valid_tritc, valid_masks, ‘1’)#sampler_train = DummySampler(train_data)
#sampler_valid = DummySampler(valid_data)train_loader = torch.utils.data.DataLoader( train_data,
batch_size=1, shuffle=False,
num_workers=5, pin_memory=True)val_loader = torch.utils.data.DataLoader( valid_data,
batch_size=1, shuffle=False,
num_workers=5, pin_memory=True)################################
model = FCN32s(n_class=1)
#model = torchfcn.models.FCN8s(n_class=1)
#vgg16 = torchfcn.models.VGG16(pretrained=True)
#model.copy_params_from_vgg16(vgg16)
model = model.cuda()optim = torch.optim.SGD(
[
{‘params’: get_parameters(model, bias=False)},
{‘params’: get_parameters(model, bias=True),
‘lr’: 1.0e-10 * 2, ‘weight_decay’: 0},
],
lr=1.0e-10,
momentum=0.99,
weight_decay=0.0005)trainer = Trainer(
cuda=True,
model=model,
optimizer=optim,
train_loader=train_loader,
val_loader=val_loader,
out=‘D:\myTools\pytorch-fcn-master\examples\voc\out_FCN32s’,
max_iter=10000,
interval_validate=4000,
)start_epoch = 0
start_iteration = 0
trainer.epoch = start_epoch
trainer.iteration = start_iteration##################
deviceid = 0
os.environ[‘CUDA_VISIBLE_DEVICES’] = “%d”%deviceid################
trainer.train()