So, this has aged a bit, but the problem persists. Below is the training script I used, with the Lightning simple profiler enabled:
import os
import numpy as np
from os.path import join, splitext, basename
import torch
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import Dataset
from treecrowndelineation import TreeCrownDelineationModel
from pytorch_lightning import Trainer, LightningDataModule
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import albumentations as A
from glob import glob
from PIL import Image
from torch.utils.data import DataLoader
torch.set_float32_matmul_precision("medium")
#%%
class ISPRSDataset(Dataset):
    def __init__(self, images, masks, outlines, dist_trafo, augmentation):
        super().__init__()
        # we assume images and labels match perfectly
        self.images = images
        self.masks = masks
        self.outlines = outlines
        self.dist_trafo = dist_trafo
        self.augmentation = augmentation

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        image = np.array(Image.open(self.images[item]))
        mask = np.array(Image.open(self.masks[item]))
        outline = np.array(Image.open(self.outlines[item]))
        dist_trafo = np.array(Image.open(self.dist_trafo[item]))
        labels = np.stack((mask, outline, dist_trafo), axis=-1)
        # return image, labels
        augmented = self.augmentation(image=image, mask=labels)
        return (augmented["image"].transpose((2, 0, 1)).astype(np.float32),
                augmented["mask"].transpose((2, 0, 1)).astype(np.float32))
class ISPRSDataModule(LightningDataModule):
    def __init__(self, images, masks, outlines, dist_trafo, augmentation, workers=6, batch_size=64):
        super().__init__()
        self.images = np.array(sorted(glob(join(images, "*.png")), key=lambda x: splitext(basename(x))[0]))
        self.masks = np.array(sorted(glob(join(masks, "*.png")), key=lambda x: "_".join(splitext(basename(x))[0].split("_")[:-1])))
        self.outlines = np.array(sorted(glob(join(outlines, "*.png")), key=lambda x: "_".join(splitext(basename(x))[0].split("_")[:-1])))
        self.dist_trafo = np.array(sorted(glob(join(dist_trafo, "*.tif")), key=lambda x: splitext(basename(x))[0].split("_")[:-2]))
        self.augmentation = augmentation
        self.workers = workers
        self.batch_size = batch_size
        # sanity checks
        assert len(self.images) > 0
        assert len(self.images) == len(self.masks) == len(self.outlines) == len(self.dist_trafo)
        # compare basenames
        im_bn = np.array([splitext(basename(x))[0] for x in self.images])
        m_bn = np.array(["_".join(splitext(basename(x))[0].split("_")[:-1]) for x in self.masks])
        o_bn = np.array(["_".join(splitext(basename(x))[0].split("_")[:-1]) for x in self.outlines])
        d_bn = np.array(["_".join(splitext(basename(x))[0].split("_")[:-2]) for x in self.dist_trafo])
        assert (im_bn == m_bn).all() and (im_bn == o_bn).all() and (im_bn == d_bn).all()

    def setup(self, stage: str) -> None:
        train_set, val_set = torch.utils.data.random_split(np.arange(len(self.images)), (0.8, 0.2))
        train_set = np.array(train_set)
        val_set = np.array(val_set)
        self.train_ds = ISPRSDataset(self.images[train_set],
                                     self.masks[train_set],
                                     self.outlines[train_set],
                                     self.dist_trafo[train_set],
                                     augmentation=self.augmentation)
        self.val_ds = ISPRSDataset(self.images[val_set],
                                   self.masks[val_set],
                                   self.outlines[val_set],
                                   self.dist_trafo[val_set],
                                   augmentation=self.augmentation)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_ds, batch_size=self.batch_size, pin_memory=False, num_workers=self.workers)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.val_ds, batch_size=self.batch_size, pin_memory=False, num_workers=self.workers)
#%%
# base = "/data_hdd/isprs-itc-seg/new/train/"
base = "/tmp/train/"
logdir = "./log"
model_save_path = "./models"
experiment_name = "isprs"
arch = "Unet-resnet18"
width = 256
batchsize = 16
in_channels = 3
devices = 1 # number of gpus, if you have multiple
accelerator = "auto" # or gpu or cpu, see lightning docs
max_epochs = 30 + 60 - 1
lr = 3E-4
training_split = 0.8
model_name = "{}_epochs={}_lr={}_width={}_bs={}".format(arch,
                                                        max_epochs,
                                                        lr,
                                                        width,
                                                        batchsize)
#%%
###################################
# training #
###################################
logger = TensorBoardLogger(logdir,
                           name=experiment_name,
                           # version=model_name,
                           default_hp_metric=False)
cp = ModelCheckpoint(os.path.abspath(model_save_path) + "/" + experiment_name,
                     model_name + "-{epoch}",
                     monitor="val/loss",
                     save_last=True,
                     save_top_k=2)
callbacks = [cp, LearningRateMonitor()]
train_augmentation = A.Compose([A.RandomCrop(width, width, always_apply=True),
                                A.RandomRotate90(),
                                A.VerticalFlip()
                                ])
# val_augmentation = A.RandomCrop(width, width, always_apply=True)
data = ISPRSDataModule(images=base+"images",
                       masks=base+"masks",
                       outlines=base+"outlines",
                       dist_trafo=base+"dist_trafo",
                       augmentation=train_augmentation)
model = TreeCrownDelineationModel(in_channels=in_channels, lr=lr)
# model = torch.compile(model)
#%%
trainer = Trainer(devices=devices,
                  accelerator=accelerator,
                  logger=logger,
                  callbacks=callbacks,
                  # checkpoint_callback=False, # set this to avoid logging into the working directory
                  max_epochs=max_epochs,
                  enable_progress_bar=False,
                  profiler="simple",
                  max_steps=50,
                  )
trainer.fit(model, data)
#%%
model.to("cpu")
t = torch.rand(1, in_channels, width, width, dtype=torch.float32)
model.to_torchscript(
    os.path.abspath(model_save_path) + "/" + experiment_name + '/' + model_name + "_jitted.pt",
    method="trace",
    example_inputs=t)
One __getitem__ call on the dataset takes about 38 ms, so assembling a batch of 64 should take roughly 2.4 s in a single process and around 400 ms with 6 worker processes. In practice, however, one batch takes around 1.7 s. All images are again on a RAM disk.
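
For completeness, here is a minimal sketch of how such timings can be reproduced outside the Trainer, so that Lightning itself is taken out of the picture. It reuses the data module defined above; the timing loop itself is only illustrative and not part of the original script:

import time

data.setup("fit")  # builds train_ds / val_ds; normally the Trainer calls this

# cost of a single __getitem__ call (opens three PNGs plus one TIFF and augments them)
t0 = time.perf_counter()
_ = data.train_ds[0]
print(f"single __getitem__: {(time.perf_counter() - t0) * 1000:.1f} ms")

# steady-state cost of a batch coming out of the DataLoader with num_workers=6
loader = data.train_dataloader()
it = iter(loader)
_ = next(it)  # discard the first batch, which also pays the worker start-up cost
for i in range(5):
    t0 = time.perf_counter()
    _ = next(it)
    print(f"batch {i}: {time.perf_counter() - t0:.2f} s")  # how long the main process waits per batch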