Applying a semantic segmentation classifier to a large image

Hello,

I have trained a semantic segmentation classifier (a UNet variant) and now need to apply it to a large stack of climate data. The dimensions of the climate data over my study area are much larger than the classifier's input size.

I know I need to break the larger image into smaller chunks and feed them through the classifier. I also know that including some overlap during inference is generally good practice to reduce edge effects in the predictions.

Has someone produced a script that illustrates an efficient way of doing this?

Thank you!

I would try to use nn.Unfold to create the smaller patches and nn.Fold to recreate the bigger output.
Note that nn.Fold will accumulate the values in the overlapping regions, so you might want to use a manual approach with permutations and some view operations.

Let me know if that would work for you. :slight_smile:
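To illustrate the accumulation behaviour mentioned above, here is a toy sketch (my own sizes, not yours: a 4x4 all-ones image and 2x2 patches with stride 1). Folding an all-ones tensor of the same shape counts how many patches cover each pixel, so dividing by it averages the overlaps instead of summing them:

```python
import torch
import torch.nn.functional as F

# 4x4 all-ones image, 2x2 patches with stride 1 -> 9 overlapping patches
x = torch.ones(1, 1, 4, 4)
patches = F.unfold(x, kernel_size=2, stride=1)          # [1, 4, 9]

# fold sums the contributions: interior pixels are covered by
# 4 patches, so they end up as 4.0 instead of 1.0
summed = F.fold(patches, output_size=(4, 4), kernel_size=2, stride=1)
print(summed[0, 0])

# folding an all-ones tensor counts how many patches cover each pixel;
# dividing by that count averages the overlaps back to the input
counts = F.fold(torch.ones_like(patches), output_size=(4, 4),
                kernel_size=2, stride=1)
print(torch.allclose(summed / counts, x))  # True
```

The same divide-by-counts trick works for averaging overlapping model predictions at full size.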

Thanks for the suggestion! Could you please help me fill out an example of how to use unfold and fold in conjunction?

For simplicity's sake, let's say we have a 3x2048x2048 image (my actual images are a lot larger) and a semantic segmentation model that has an input resolution of 256x256.

The following code, for example, doesn't produce a tensor of 256x256 windows like I'd expect, which I assume means I am using the unfold function incorrectly:

import torch
import torch.nn.functional as f


# Load in a big image
big_image = torch.rand(1,3,2048,2048)

# Unfold the image into 3 x 256 x 256
windows = f.unfold(big_image, kernel_size=256)

print(windows.shape)

Your approach works, but you should specify a stride, otherwise you get a very large output (the default stride of 1 produces one patch per pixel position)!

import torch
import torch.nn.functional as f

# load in a big image
big_image = torch.randn(1, 3, 2048, 2048)

# unfold the image into (b_size, 3 x 256 x 256, num_of_folds)
small_flat_imgs = f.unfold(big_image, kernel_size=256, stride=64)
print(small_flat_imgs.shape)
>> torch.Size([1, 196608, 841])


# test if the upper-left corner patch is identical to the upper-left crop
print(torch.equal(small_flat_imgs[:,:,0].view(-1, 3, 256, 256), big_image[:, :, :256, :256]))
>> True

# permute before reshaping so each patch's values stay contiguous;
# a plain .view(-1, 3, 256, 256) here would scramble the patch contents
small_imgs = small_flat_imgs.permute(0, 2, 1).reshape(-1, 3, 256, 256)
print(small_imgs.shape)
>> torch.Size([841, 3, 256, 256])

For F.fold I'm too stupid and would be happy about an answer from ptrblck!
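A minimal end-to-end sketch of the fold step, under some assumptions: the `Conv2d` below is a hypothetical placeholder for the trained UNet (any model mapping 3 channels to `num_classes` channels at the same resolution works), and the sizes are reduced to a 512x512 image with 128x128 patches so it runs quickly; the same code applies unchanged to 2048 / 256. Overlaps are averaged by dividing by a folded all-ones tensor:

```python
import torch
import torch.nn.functional as F

# hypothetical stand-in for the trained segmentation model
num_classes = 5
model = torch.nn.Conv2d(3, num_classes, kernel_size=1)

big_image = torch.randn(1, 3, 512, 512)
k, s = 128, 64                       # patch size and stride (overlap = k - s)

# 1) unfold into flat patches: [1, 3*k*k, num_patches]
flat = F.unfold(big_image, kernel_size=k, stride=s)
n = flat.size(-1)

# 2) permute *before* reshaping so each patch stays contiguous
patches = flat.permute(0, 2, 1).reshape(n, 3, k, k)

# 3) run the model patch-wise; split() keeps memory bounded
with torch.no_grad():
    preds = torch.cat([model(p) for p in patches.split(16)], dim=0)

# 4) flatten predictions back into fold's layout: [1, C*k*k, num_patches]
preds_flat = preds.reshape(1, n, num_classes * k * k).permute(0, 2, 1)

# 5) fold sums the overlaps; dividing by a folded ones tensor averages them
summed = F.fold(preds_flat, output_size=big_image.shape[-2:],
                kernel_size=k, stride=s)
counts = F.fold(torch.ones_like(preds_flat), output_size=big_image.shape[-2:],
                kernel_size=k, stride=s)
logits = summed / counts             # [1, num_classes, 512, 512]
print(logits.shape)
```

This works as written only when `(H - k)` is divisible by `s` so the patches tile the image completely; otherwise pad the input first (e.g. with `F.pad`) and crop the folded output back.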