Hi,
I am trying to setup a fully convolutional network to segment my microscopy images. You can see the image set I use from this link. There are four folders: original image for training, binary masks for training and the same pair for validation.
Thanks to this active forum I got the script running fine on cpu for many epochs, however when I run it on gpu it runs only for the very first validation round, however after that when the script reaches the command self.optim.step()
, I get a cuda memory error:
RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 2.00 GiB total capacity; 1.06 GiB already allocated; 123.46 MiB free; 26.89 MiB cached)
Now I wonder if and where there is a memory leakage, or whether is this an honest to god out-of-memory error because of my hardware limitation? Thank you!
Sincerely,
import glob
import torch
from torch.utils.data.dataset import Dataset # For custom data-sets
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torchvision
import matplotlib.pyplot as plt
from torch._C import *
import torch.nn as nn
import math
import os.path as osp
import scipy.misc
import tqdm
import random
import datetime
import os
import pytz
import shutil
from distutils.version import LooseVersion
#import os
#import os.path as osp
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as utils_data
import gc
#####################
train_tritc_dir = 'train_tritc\\'
train_masks_dir = 'train_masks\\1\\'
valid_tritc_dir = 'valid_tritc\\'
valid_masks_dir = 'valid_masks\\1\\'
train_tritcs = glob.glob('\\\\Moana\\oc1\\OcellO Projects\\VitroScan\\190430_create_trainset_deepl\\OVCA2018_08\\'+train_tritc_dir+'*png')
train_masks = glob.glob('\\\\Moana\\oc1\\OcellO Projects\\VitroScan\\190430_create_trainset_deepl\\OVCA2018_08\\'+train_masks_dir+'*png')
valid_tritcs = glob.glob('\\\\Moana\\oc1\\OcellO Projects\\VitroScan\\190430_create_trainset_deepl\\OVCA2018_08\\'+valid_tritc_dir+'*png')
valid_masks = glob.glob('\\\\Moana\\oc1\\OcellO Projects\\VitroScan\\190430_create_trainset_deepl\\OVCA2018_08\\'+valid_masks_dir+'*png')
###########################################################
class CustomDataset(Dataset):
def __init__(self, image_paths, target_paths, train=True): # initial logic happens like transform
self.image_paths = image_paths
self.target_paths = target_paths
self.transforms = transforms.ToTensor()
def __getitem__(self, index):
image = Image.open(self.image_paths[index])
image = np.array(image).astype(np.float32)
mask = Image.open(self.target_paths[index])
mask = np.array(mask).astype(np.uint8)
t_image = self.transforms(image)
mask = self.transforms(mask)
return t_image, mask
def __len__(self): # return count of sample we have
return len(self.image_paths)
##############################################
#######################################################
class FCN32s(nn.Module):
def __init__(self, n_class=2):
super(FCN32s, self).__init__()
# conv1
self.conv1_1 = nn.Conv2d(1, 64, 3, padding=100)
self.relu1_1 = nn.ReLU(inplace=True)
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
self.relu1_2 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2
# conv2
self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
self.relu2_1 = nn.ReLU(inplace=True)
self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
self.relu2_2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4
# conv3
self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
self.relu3_1 = nn.ReLU(inplace=True)
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
self.relu3_2 = nn.ReLU(inplace=True)
self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
self.relu3_3 = nn.ReLU(inplace=True)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8
# conv4
self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
self.relu4_1 = nn.ReLU(inplace=True)
self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
self.relu4_2 = nn.ReLU(inplace=True)
self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
self.relu4_3 = nn.ReLU(inplace=True)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16
# conv5
self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_1 = nn.ReLU(inplace=True)
self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_2 = nn.ReLU(inplace=True)
self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_3 = nn.ReLU(inplace=True)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32
# fc6
self.fc6 = nn.Conv2d(512, 4096, 7)
self.relu6 = nn.ReLU(inplace=True)
self.drop6 = nn.Dropout2d()
# fc7
self.fc7 = nn.Conv2d(4096, 4096, 1)
self.relu7 = nn.ReLU(inplace=True)
self.drop7 = nn.Dropout2d()
self.score_fr = nn.Conv2d(4096, n_class, 1)
self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32,
bias=False)
#self._initialize_weights()
def forward(self, x):
h = x
h = self.relu1_1(self.conv1_1(h))
h = self.relu1_2(self.conv1_2(h))
h = self.pool1(h)
h = self.relu2_1(self.conv2_1(h))
h = self.relu2_2(self.conv2_2(h))
h = self.pool2(h)
h = self.relu3_1(self.conv3_1(h))
h = self.relu3_2(self.conv3_2(h))
h = self.relu3_3(self.conv3_3(h))
h = self.pool3(h)
h = self.relu4_1(self.conv4_1(h))
h = self.relu4_2(self.conv4_2(h))
h = self.relu4_3(self.conv4_3(h))
h = self.pool4(h)
h = self.relu5_1(self.conv5_1(h))
h = self.relu5_2(self.conv5_2(h))
h = self.relu5_3(self.conv5_3(h))
h = self.pool5(h)
h = self.relu6(self.fc6(h))
h = self.drop6(h)
h = self.relu7(self.fc7(h))
h = self.drop7(h)
h = self.score_fr(h)
h = self.upscore(h)
h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous()
return h
##########################
def get_parameters(model, bias=False):
import torch.nn as nn
modules_skipped = (
nn.ReLU,
nn.MaxPool2d,
nn.Dropout2d,
nn.Sequential,
FCN32s,
#torchfcn.models.FCN16s,
#torchfcn.models.FCN8s,
)
for m in model.modules():
if isinstance(m, nn.Conv2d):
if bias:
yield m.bias
else:
yield m.weight
elif isinstance(m, nn.ConvTranspose2d):
# weight is frozen because it is just a bilinear upsampling
if bias:
assert m.bias is None
elif isinstance(m, modules_skipped):
continue
else:
raise ValueError('Unexpected module: %s' % str(m))
################################################################
class Trainer(object):
def __init__(self, cuda, model, optimizer,
train_loader, val_loader, out, max_iter,
size_average=False, interval_validate=None):
self.cuda = cuda
self.model = model
self.optim = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.timestamp_start = \
datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
self.size_average = size_average
if interval_validate is None:
self.interval_validate = len(self.train_loader)
else:
self.interval_validate = interval_validate
self.out = out
if not osp.exists(self.out):
os.makedirs(self.out)
self.log_headers = [
'epoch',
'iteration',
'train/loss',
'train/acc',
'train/acc_cls',
'train/mean_iu',
'train/fwavacc',
'valid/loss',
'valid/acc',
'valid/acc_cls',
'valid/mean_iu',
'valid/fwavacc',
'elapsed_time',
]
if not osp.exists(osp.join(self.out, 'log.csv')):
with open(osp.join(self.out, 'log.csv'), 'w') as f:
f.write(','.join(self.log_headers) + '\n')
self.epoch = 0
self.iteration = 0
self.max_iter = max_iter
self.best_mean_iu = 0
def validate(self):
training = self.model.training
self.model.eval()
#n_class = len(self.val_loader.dataset.class_names)
n_class = 2
val_loss = 0
#visualizations = []
label_trues, label_preds = [], []
#for batch_idx, (data, target) in tqdm.tqdm(
for batch_idx, cur_data_dic in tqdm.tqdm(
enumerate(self.val_loader), total=len(self.val_loader),
desc='Valid iteration=%d' % self.iteration, ncols=80,
leave=False):
#data = cur_data_dic['image']
data = cur_data_dic[0]
target = cur_data_dic[1]
#target = cur_data_dic['labels']
#
if self.cuda:
#
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
with torch.no_grad():
score = self.model(data)
loss = cross_entropy2d(score, target,
size_average=self.size_average)
loss_data = loss.data.item()
if np.isnan(loss_data):
raise ValueError('loss is nan while validating')
val_loss += loss_data / len(data)
imgs = data.data.cpu()
lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
lbl_true = target.data.cpu()
for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
#XXXXXXXXXX#
img = img.numpy()
img = img.transpose(1, 2, 0)
img = img.astype(np.uint8)
img = img[:, :, ::-1]
lt = lt.numpy()
#XXXXXXXXXX#
#print("XXX")
#print(img.shape)
#print(lt.shape)
#print(lp.shape)
#img, lt = self.val_loader.dataset.untransform(img, lt)
label_trues.append(lt)
label_preds.append(lp)
#if len(visualizations) < 9:
# viz = fcn.utils.visualize_segmentation(
# lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class)
# visualizations.append(viz)
metrics = label_accuracy_score(
label_trues, label_preds, n_class)
out = osp.join(self.out, 'visualization_viz')
if not osp.exists(out):
os.makedirs(out)
#out_file = osp.join(out, 'iter%012d.jpg' % self.iteration)
#scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
val_loss /= len(self.val_loader)
with open(osp.join(self.out, 'log.csv'), 'a') as f:
elapsed_time = (
datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
self.timestamp_start).total_seconds()
log = [self.epoch, self.iteration] + [''] * 5 + \
[val_loss] + list(metrics) + [elapsed_time]
log = map(str, log)
f.write(','.join(log) + '\n')
mean_iu = metrics[2]
is_best = mean_iu > self.best_mean_iu
if is_best:
self.best_mean_iu = mean_iu
torch.save({
'epoch': self.epoch,
'iteration': self.iteration,
'arch': self.model.__class__.__name__,
'optim_state_dict': self.optim.state_dict(),
'model_state_dict': self.model.state_dict(),
'best_mean_iu': self.best_mean_iu,
}, osp.join(self.out, 'checkpoint.pth.tar'))
if is_best:
shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
osp.join(self.out, 'model_best.pth.tar'))
if training:
self.model.train()
def train_epoch(self):
self.model.train()
#n_class = len(self.train_loader.dataset.class_names)
n_class = 2
for batch_idx, (data, target) in tqdm.tqdm(
enumerate(self.train_loader), total=len(self.train_loader),
desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
iteration = batch_idx + self.epoch * len(self.train_loader)
if self.iteration != 0 and (iteration - 1) != self.iteration:
continue # for resuming
self.iteration = iteration
if self.iteration % self.interval_validate == 0:
self.validate()
assert self.model.training
if self.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
###############
#########
#print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
self.optim.zero_grad()
score = self.model(data)
loss = cross_entropy2d(score, target,
size_average=self.size_average)
loss /= len(data)
loss_data = loss.data.item()
if np.isnan(loss_data):
raise ValueError('loss is nan while training')
loss.backward()
"""
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
print(type(obj), obj.size())
except:
pass
"""
#print("YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY")
self.optim.step()
#print("ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZXXXX")
metrics = []
lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
lbl_true = target.data.cpu().numpy()
acc, acc_cls, mean_iu, fwavacc = \
label_accuracy_score(
lbl_true, lbl_pred, n_class=n_class)
metrics.append((acc, acc_cls, mean_iu, fwavacc))
metrics = np.mean(metrics, axis=0)
with open(osp.join(self.out, 'log.csv'), 'a') as f:
elapsed_time = (
datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
self.timestamp_start).total_seconds()
log = [self.epoch, self.iteration] + [loss_data] + \
metrics.tolist() + [''] * 5 + [elapsed_time]
log = map(str, log)
f.write(','.join(log) + '\n')
if self.iteration >= self.max_iter:
break
def train(self):
max_epoch = int(math.ceil(1. * self.max_iter / len(self.train_loader)))
for epoch in tqdm.trange(self.epoch, max_epoch,
desc='Train', ncols=80):
self.epoch = epoch
self.train_epoch()
if self.iteration >= self.max_iter:
break
#################################
def _fast_hist(label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return hist
def label_accuracy_score(label_trues, label_preds, n_class):
"""Returns accuracy score evaluation result.
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
"""
hist = np.zeros((n_class, n_class))
for lt, lp in zip(label_trues, label_preds):
hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
acc = np.diag(hist).sum() / hist.sum()
with np.errstate(divide='ignore', invalid='ignore'):
acc_cls = np.diag(hist) / hist.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
with np.errstate(divide='ignore', invalid='ignore'):
iu = np.diag(hist) / (
hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
)
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / hist.sum()
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
return acc, acc_cls, mean_iu, fwavacc
def cross_entropy2d(input, target, weight=None, size_average=True):
#################
n, c, h, w = input.size()
# log_p: (n, c, h, w)
if LooseVersion(torch.__version__) < LooseVersion('0.3'):
# ==0.2.X
log_p = F.log_softmax(input)
else:
# >=0.3
log_p = F.log_softmax(input, dim=1)
# log_p: (n*h*w, c)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
log_p = log_p.view(-1, c)
###############
# target: (n*h*w,)
mask = target >= 0
target = target[mask]
target = target.long()
loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
if size_average:
loss /= mask.data.sum()
return loss
#####################################
model = FCN32s(n_class=2)
#Comment the line below to run on cpu
#Also change the input parameter to initialization of Trainer
model = model.cuda()
optim = torch.optim.SGD(
[
{'params': get_parameters(model, bias=False)},
{'params': get_parameters(model, bias=True),
'lr': 1.0e-10 * 2, 'weight_decay': 0},
],
lr=1.0e-10,
momentum=0.99,
weight_decay=0.0005)
############################################################
###########################################################
train_dataset = CustomDataset(train_tritcs, train_masks)
valid_dataset = CustomDataset(valid_tritcs, valid_masks)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
###########################################
trainer = Trainer(
cuda=True,
model=model,
optimizer=optim,
train_loader=train_loader,
val_loader=valid_loader,
out='D:\\myTools\\\my_FCN_190515',
max_iter=10000,
interval_validate=4000,
)
#######################################
trainer.train()