How to create patches / windows of image dataset?


(Bodo Kaiser) #1

I am working with the MNIBITE dataset, which contains magnetic resonance (MR) and ultrasound (US) images of human brains.

Because the US only covers a small portion of the MR image I want to apply a sliding window (e.g. skimage.util.view_as_windows) over the 466x394 US and MR images so that I can remove empty patches from training.

I would now like to know where the best place to implement this type of preprocessing would be.

  1. I can’t put this into the torch.utils.data.Dataset instance as __getitem__ would then return n-patches.
  2. I can’t put this into the training loop as the batch size would vary.

It may be best to put it somewhere right before the torch.utils.data.DataLoader but I don’t see how. Any suggestions?


(Adam Paszke) #2

Yeah that’s a tricky one. There’s no way to tell the DataLoader to dynamically skip some inputs and replace them with the next ones, because it would violate the guarantee that it will always return the data in the same order as given by the sampler.

Why is having variable batch size so bad? The only solution I can think of is to pre-process the whole dataset, and save a list of tuples: (filename, offset_x, offset_y) which would contain all non-zero patches. This would give you a sequence of all good patches for use at training time. You could add an option to preprocess it once and later pickle the data somewhere so it’s faster next time.
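
A minimal sketch of the index idea above, assuming each image fits in memory as a 2D NumPy array and that "empty" means all-zero. The names `nonempty_patch_offsets` and `build_index`, and the patch and stride sizes, are hypothetical and only serve the illustration.

```python
import numpy as np


def nonempty_patch_offsets(image, patch=(64, 64), stride=(32, 32)):
    """Return (offset_x, offset_y) for every window with at least one non-zero pixel.

    offset_x is the column and offset_y the row of the window's top-left corner.
    """
    patch_h, patch_w = patch
    stride_h, stride_w = stride
    height, width = image.shape
    offsets = []
    for y in range(0, height - patch_h + 1, stride_h):
        for x in range(0, width - patch_w + 1, stride_w):
            if np.any(image[y:y + patch_h, x:x + patch_w]):
                offsets.append((x, y))
    return offsets


def build_index(images, patch=(64, 64), stride=(32, 32)):
    """Flatten a {filename: 2D array} mapping into (filename, offset_x, offset_y) tuples."""
    return [
        (name, x, y)
        for name, image in images.items()
        for (x, y) in nonempty_patch_offsets(image, patch, stride)
    ]
```

The resulting list is plain Python data, so it can be written once with `pickle.dump` and reloaded on later runs instead of rescanning the images.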


(Bodo Kaiser) #3

Variable batch size wouldn’t be bad per se. That said, the data unfortunately is nearly delta-shape distributed :smiley:

As you said I think I will need to add another stage which converts the data from HDF5 to patches and then feed that data into pytorch.

Do you have a recommendation on what dataset interface to use? Did you experience any problems using the ImageFolder dataset with lots of image patches (IO wise)?

Thanks in advance!


(Pierre Antoine Ganaye) #4

Hi @bodokaiser,
I am also working on medical imaging (brain MRI), and have already solved the problem you mentioned. You need to write your own Dataset that computes all the possible patch positions and stores them. When you iterate over this dataset, it extracts and returns the patches.
It has the advantage of only storing the patch positions (use a numpy array or tensor). If you have any trouble, please ask. I have implemented multiple 2D patch extraction with constant padding.
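
One way this could look, as a rough sketch rather than the poster's actual code. It assumes the images are already loaded as 2D NumPy arrays, leaves out padding, and uses a plain class named `PatchDataset` (a hypothetical name) exposing `__len__` and `__getitem__`. In a PyTorch project the class would typically subclass `torch.utils.data.Dataset`.

```python
import numpy as np


class PatchDataset:
    """Map-style dataset that stores patch positions only and crops on access."""

    def __init__(self, inputs, targets, patch=32, stride=32):
        # inputs/targets: lists of 2D arrays; inputs[i] and targets[i] share a shape.
        self.inputs = inputs
        self.targets = targets
        self.patch = patch
        positions = []
        for i, image in enumerate(inputs):
            height, width = image.shape
            for y in range(0, height - patch + 1, stride):
                for x in range(0, width - patch + 1, stride):
                    positions.append((i, y, x))
        # One row per patch: (image index, top row, left column).
        self.positions = np.asarray(positions, dtype=np.int64).reshape(-1, 3)

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        i, y, x = self.positions[idx]
        p = self.patch
        return (
            self.inputs[i][y:y + p, x:x + p],
            self.targets[i][y:y + p, x:x + p],
        )
```

Because every item is a fixed-size patch, batches keep a constant shape, and memory use grows with the number of positions rather than the number of extracted patches.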


(Adam Paszke) #5

@bodokaiser no, we’ve used ImageFolder for large datasets and it worked fine, did you encounter any problems? IO wise the dataset doesn’t matter too much (it should just load a single image), but you definitely want to use a pool of workers in DataLoader. Also, I’ve heard that OpenCV2 might be faster than PIL, so you might try it out if you’re bottlenecked by the data loading.

You don’t need to save the patches ahead of time, it should be enough to store the file names and locations to retrieve them later. But if you don’t have a lot of data it might be simpler to preprocess the dataset.


(Bodo Kaiser) #6

@apaszke until now everything runs smoothly :smiley: I also think my dataset is too small to cause problems (I had something like OS X I/O limits in mind, though those are easy to solve).


(Bodo Kaiser) #7

@trypag Interesting! What are you doing exactly? Do you mind if we share some ideas (bodo.kaiser at me com)?


(Pierre Antoine Ganaye) #8

sure, I emailed you!!


(Adam Paszke) #9

@bodokaiser We provide tools that are fast, but we can’t do anything about the system limits ¯_(ツ)_/¯
Those kinds of issues you have to fix yourself.


(Bodo Kaiser) #10

When this issue is resolved in torchvision, another solution to this problem may be to use torchvision.transforms.RandomCrop on both targets and inputs, which should behave over the epochs as if you were training on patches (maybe this approach would even be superior as it takes patches from different places).
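
Until a transform can crop both images together, the same effect can be sketched by sampling one crop position and applying it to input and target alike. This is an illustration in plain NumPy, not torchvision's API; `paired_random_crop` is a hypothetical helper.

```python
import random

import numpy as np


def paired_random_crop(image, target, size, rng=random):
    """Crop the same randomly placed window from image and target.

    Both arrays must agree on their last two (height, width) dimensions.
    """
    if image.shape[-2:] != target.shape[-2:]:
        raise ValueError("image and target must have the same height and width")
    height, width = image.shape[-2:]
    crop_h, crop_w = size
    if crop_h > height or crop_w > width:
        raise ValueError("crop size exceeds image size")
    # random.randint includes both endpoints, so the last valid offset is reachable.
    top = rng.randint(0, height - crop_h)
    left = rng.randint(0, width - crop_w)
    window = (Ellipsis, slice(top, top + crop_h), slice(left, left + crop_w))
    return image[window], target[window]
```

Sampling the offset once is the important part: calling an independent random crop on each image would break the pixel correspondence between input and target.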

To crop MRI and US to the same size (so we already removed most area with zero data) you can use: https://gist.github.com/bodokaiser/b5697f95336be11aa09f6e246901e0d4


#11

Hi @trypag ,
I am working on it too. One small difference: I am focused on creating patches from a multi-channel conv_map inside forward(). All I know so far is that torch.gather and torch.masked_select may help. Would you share some ideas from your work with me? (chungweilam@163.com)
Many thanks.


(Pierre Antoine Ganaye) #12

Hi @Tepp,
I will just copy the mail I sent to Bodo, nothing really fancy :slight_smile:

Hi,
I am working on brain MRI segmentation, and I am trying to reproduce a paper based on a patch-based approach. So I had to implement 2D patch extraction for 3D images. I hope to share this code publicly someday, but I don’t think it will happen soon. I can offer my help to guide you if you want.
I used SimpleITK; you can also use nibabel to read the images. The first step is to load your image, then evaluate the number of valid patch positions you can extract. This will give you an array of positions. When iterating over your dataset, you return the patch at the given position together with the corresponding label.

See you


(Nick) #13

@trypag ah man I would stay away from the patch based approach for MRI segmentation… Those methods are incredibly slow at test time and end up taking about as long as (or longer than) traditional methods with much worse accuracy. 2D conv + conv-transpose on entire slices OR 3D conv + conv-transpose on the entire image works much better. We have ~1 second models for skull stripping (non-DL methods take ~1 min) and ~30s models for tissue segmentation (non-DL takes a few hours)… These are all using conv + conv-transpose ensembles.

Anyways, a very simple approach to sampling 2D patches from 3D MRI is to extract a random slice (see my package torchsample for that transform) then just take a random crop. You can use the “MultiSampler” sampler in torchsample to go through each image more than once in a single epoch. No need to extract EVERY possible patch haha just do it randomly at train time.
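
The slice-then-crop sampling described above can be sketched in plain NumPy. This does not use torchsample; `random_slice_patch` is a hypothetical helper, and it assumes the volume and its label map are 3D arrays of identical shape.

```python
import numpy as np


def random_slice_patch(volume, labels, size, axis=0, rng=None):
    """Pick a random slice along `axis`, then a random 2D crop of shape `size`."""
    if volume.shape != labels.shape:
        raise ValueError("volume and labels must have the same shape")
    rng = np.random.default_rng() if rng is None else rng
    # Generator.integers(n) draws from [0, n).
    k = rng.integers(volume.shape[axis])
    image = np.take(volume, k, axis=axis)
    label = np.take(labels, k, axis=axis)
    height, width = image.shape
    crop_h, crop_w = size
    if crop_h > height or crop_w > width:
        raise ValueError("crop size exceeds slice size")
    top = rng.integers(height - crop_h + 1)
    left = rng.integers(width - crop_w + 1)
    return (
        image[top:top + crop_h, left:left + crop_w],
        label[top:top + crop_h, left:left + crop_w],
    )
```

Calling this several times per volume in one epoch plays a similar role to visiting each image more than once, without enumerating every possible patch.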

It’s still a fun exercise to reimplement them tho… which paper specifically are you interested in? I think I have implementations of most of them and can share or help you develop the right sampler.


(Pierre Antoine Ganaye) #14

@ncullen93 Yes, I am not sure this is the right approach either. This is my first try at MRI segmentation :slight_smile:
The good point about the patch-based approach is that you have much more data. I am afraid using full slices will seriously reduce the dataset. Did you have any problem with that? How did you solve it?


(Nick) #15

Yeah you kinda have more data but it’s misleading bc you don’t have more region-specific subject variability. Data is definitely an issue… I train on a lot of subjects (a few thousand) from multiple large datasets… Also, the KEY to MRI segmentation is data augmentation and applying random affine transforms specifically since images differ that way across scanners - hence why I developed torchsample. With good augmentation we have done quite well with ~200 subjects.
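
To make the random affine idea concrete, here is a small NumPy sketch. It is not torchsample's implementation: the function names and parameter ranges are hypothetical, and nearest-neighbour sampling is used so the same warp can be applied to a label map (intensity images would normally use smoother interpolation).

```python
import numpy as np


def random_affine_matrix(rng, max_rotation_deg=10.0, max_scale=0.1, max_shift=5.0):
    """Sample a 3x3 homogeneous matrix: rotation, isotropic scale, (row, col) shift."""
    theta = np.deg2rad(rng.uniform(-max_rotation_deg, max_rotation_deg))
    scale = 1.0 + rng.uniform(-max_scale, max_scale)
    shift_r, shift_c = rng.uniform(-max_shift, max_shift, size=2)
    c, s = np.cos(theta) * scale, np.sin(theta) * scale
    return np.array([[c, -s, shift_r], [s, c, shift_c], [0.0, 0.0, 1.0]])


def apply_affine_nearest(image, matrix, fill=0):
    """Warp a 2D image about its centre with nearest-neighbour sampling."""
    height, width = image.shape
    centre_r, centre_c = (height - 1) / 2.0, (width - 1) / 2.0
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    out_coords = np.stack(
        [rows.ravel() - centre_r, cols.ravel() - centre_c, np.ones(height * width)]
    )
    # Map every output pixel back to its source location (inverse warping).
    src = np.linalg.inv(matrix) @ out_coords
    r = np.rint(src[0] + centre_r).astype(int)
    c = np.rint(src[1] + centre_c).astype(int)
    valid = (r >= 0) & (r < height) & (c >= 0) & (c < width)
    out = np.full(height * width, fill, dtype=image.dtype)
    out[valid] = image[r[valid], c[valid]]
    return out.reshape(height, width)
```

Sampling one matrix per training example and applying it to both the image and its segmentation keeps the two aligned.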


(Roger) #16

Hi @trypag, I am also working on medical imaging segmentation, and have implemented some algorithms mainly in Tensorflow and Caffe. I usually store my data in HDF5 files and then read those during training. Could you explain a bit more about how to “write your own Dataset”?
In my case I have several HDF5 files (one for each subject), and each file has several samples. I am not quite sure how I can read these h5 files using the DataLoader utility. The problem is that the number of samples is not the number of HDF5 files, because each file has a different number of samples, and if I want to use the Dataset class, it says that I should implement __len__ and __getitem__.
Thanks!
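
For the case of several files with different sample counts, one possible approach is to keep a running total of the counts and translate a global index into a (file, local index) pair. The sketch below shows only that bookkeeping; `ConcatIndex` is a hypothetical name, and it assumes the per-file counts are read once up front (for example from each HDF5 file's dataset length).

```python
from bisect import bisect_right
from itertools import accumulate


class ConcatIndex:
    """Map a global sample index to (file index, index within that file)."""

    def __init__(self, counts):
        # counts[i] is the number of samples stored in file i.
        self.cumulative = list(accumulate(counts))

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def locate(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError("sample index out of range")
        # First file whose cumulative count is strictly greater than idx.
        file_idx = bisect_right(self.cumulative, idx)
        start = self.cumulative[file_idx - 1] if file_idx else 0
        return file_idx, idx - start
```

A Dataset's `__len__` can then return `len(index)`, and `__getitem__` can call `index.locate(i)` before reading the sample from the chosen file. PyTorch's `torch.utils.data.ConcatDataset` does similar bookkeeping if each file is wrapped in its own Dataset.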


(Meghal Dani) #17

@ncullen93 I am looking for an implementation of the paper titled "A general prediction model for the detection of ADHD and Autism using structural and functional MRI", where 3D patch sampling is done on MRI images and the patches are then passed to an autoencoder to reduce the filter size. Can you suggest a method to do all of this? I am badly stuck on how to do it.