Applying a semantic segmentation classifier to a large image

Hello,

I have trained a semantic segmentation classifier (a UNet variant) and now need to apply it to a large stack of climate data. The dimensions of the climate data over my study area are much larger than the classifier's input size.

I know I need to break the larger image into smaller chunks and feed them through the classifier. I also know that including some overlap during inference is generally good practice to reduce edge effects in the predictions.
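
To make the chunking concrete, I picture computing the top-left corners of the overlapping tiles along these lines (the image size, tile size, and overlap below are just illustrative numbers):

# illustrative numbers only
H, W = 4096, 4096
tile, overlap = 256, 32
stride = tile - overlap

# top-left corners of overlapping tiles; clamp the last row/column to the image edge
rows = list(range(0, H - tile, stride)) + [H - tile]
cols = list(range(0, W - tile, stride)) + [W - tile]
origins = [(r, c) for r in rows for c in cols]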

Has anyone produced a script that illustrates an efficient way of doing this?

Thank you!

I would try to use nn.Unfold to create the smaller patches and nn.Fold to recreate the bigger output.
Note that nn.Fold will accumulate the values in the overlapping regions, so you might want to use a manual approach with permutations and some view operations.
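
For example, this toy-sized sketch (illustrative 8x8 image, 4x4 patches, stride 2) shows the accumulation and one way to undo it by dividing by a fold of a tensor of ones:

import torch
import torch.nn.functional as F

# toy sizes for illustration: 1-channel 8x8 image, 4x4 patches, stride 2
x = torch.randn(1, 1, 8, 8)

patches = F.unfold(x, kernel_size=4, stride=2)                        # [1, 16, 9]
recon = F.fold(patches, output_size=(8, 8), kernel_size=4, stride=2)

# fold sums overlapping patches, so recon != x in the overlapped regions;
# folding a tensor of ones gives the per-pixel overlap count to divide by
divisor = F.fold(F.unfold(torch.ones_like(x), kernel_size=4, stride=2),
                 output_size=(8, 8), kernel_size=4, stride=2)

print(torch.allclose(recon / divisor, x))  # True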

Let me know if that would work for you. :slight_smile:

Thanks for the suggestion! Could you please help me fill out an example of how to use unfold and fold together?

For simplicity’s sake, let’s say we have a 3x2048x2048 image (my actual images are a lot larger) and a semantic segmentation model with an input resolution of 256x256.

The following code, for example, doesn’t produce a tensor of 256x256 windows like I’d expect, which I assume means I am using the unfold function incorrectly:

import torch
import torch.nn.functional as F


# Load in a big image
big_image = torch.rand(1, 3, 2048, 2048)

# Unfold the image into 3 x 256 x 256 windows
windows = F.unfold(big_image, kernel_size=256)

print(windows.shape)

Your approach works, but you should pass a stride; otherwise (with the default stride of 1) you get a very large output!

import torch
import torch.nn.functional as F

# load in a big image
big_image = torch.randn(1, 3, 2048, 2048)

# unfold the image into (b_size, 3 * 256 * 256, num_of_patches)
small_flat_imgs = F.unfold(big_image, kernel_size=256, stride=64)
print(small_flat_imgs.shape)
>> torch.Size([1, 196608, 841])


# test if the upper-left patch is identical to the upper-left corner of the image
print(torch.equal(small_flat_imgs[:, :, 0].view(-1, 3, 256, 256), big_image[:, :, :256, :256]))
>> True

# move the patch dimension in front of the flattened pixels before reshaping,
# otherwise the view mixes values from different patches
small_imgs = small_flat_imgs.permute(0, 2, 1).reshape(-1, 3, 256, 256)
print(small_imgs.shape)
>> torch.Size([841, 3, 256, 256])

For F.fold I’m too stupid and would be happy about an answer from ptrblck!
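
In case it helps anyone landing on this thread later, here is a rough, untested sketch of how F.fold could stitch the patch predictions from the example above back together, averaging the overlaps by dividing by a folded tensor of ones. The segmentation model call is only a placeholder (an identity stand-in is used so the snippet runs on its own), and note that stride=64 means heavy overlap, so the intermediate tensors are large:

import torch
import torch.nn.functional as F

big_image = torch.randn(1, 3, 2048, 2048)

# unfold into overlapping 256x256 patches with stride 64, as above
flat = F.unfold(big_image, kernel_size=256, stride=64)             # [1, 3*256*256, 841]
patches = flat.permute(0, 2, 1).reshape(-1, 3, 256, 256)           # [841, 3, 256, 256]

# run the segmentation model on the patches, e.g.
# preds = torch.cat([model(p.unsqueeze(0)) for p in patches])      # [841, n_classes, 256, 256]
preds = patches                                                    # identity stand-in for this sketch
n_classes = preds.shape[1]

# back to the unfold layout [1, n_classes*256*256, 841], then fold and average the overlaps
flat_preds = preds.reshape(1, -1, n_classes * 256 * 256).permute(0, 2, 1).contiguous()
summed = F.fold(flat_preds, output_size=(2048, 2048), kernel_size=256, stride=64)
counts = F.fold(F.unfold(torch.ones(1, 1, 2048, 2048), kernel_size=256, stride=64),
                output_size=(2048, 2048), kernel_size=256, stride=64)
stitched = summed / counts                                         # [1, n_classes, 2048, 2048]

print(torch.allclose(stitched, big_image))  # True for the identity stand-in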

I feel like I should comment to avoid leaving an incomplete thread.

Ultimately, I just used the windowed reading functionality of Rasterio.
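
Roughly along these lines (the file path and window geometry are placeholders):

import rasterio
from rasterio.windows import Window

# "climate_stack.tif" and the window geometry are placeholders
with rasterio.open("climate_stack.tif") as src:
    window = Window(col_off=0, row_off=0, width=256, height=256)
    patch = src.read(window=window)  # numpy array of shape (bands, 256, 256)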

I then glued the windowed outputs back together using np.block(). Probably not the most efficient method, but it works well enough for my needs.
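
That is, something like this, with a hypothetical 2x2 grid of predicted tiles:

import numpy as np

# hypothetical 2x2 grid of non-overlapping 256x256 prediction tiles
tiles = [[np.zeros((256, 256)), np.ones((256, 256))],
         [np.ones((256, 256)), np.zeros((256, 256))]]

mosaic = np.block(tiles)  # stitched array of shape (512, 512)
print(mosaic.shape)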