How to efficiently subsample from large images


I’m new to PyTorch and deep learning in general. I’m developing a bacterial cell segmentation tool for microscopy with PyTorch and a U-Net. Since bacterial cells are very small (~1 micron wide x 3 microns long), they are only 20 or so pixels wide, so I can’t simply load my images (1460 x 1936 pixels) and scale them down without losing critical information. Instead, we’ve been subsampling regions (160 x 160 pixels) small enough to run on a GPU in batches of 16 or so, and stitching the model predictions together at the end to get the full image mask. I’m wondering if there is a better or more established method for this kind of subsampling.

Any help would be greatly appreciated.

Here’s some example code:

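import numpy as np
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter
from torch.utils.data import TensorDataset

#import raw image datasets and resize if necessary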
if len(training_data) == len(training_mask): #check list of training and label images
    train_set_size = len(training_data)
    X_data = np.zeros([train_set_size,1460,1936])
    Y_data = np.zeros([train_set_size,1460,1936])
n = 0    
for X_file, y_file, in zip(training_data,training_mask):
    #Load raw images .tiff .tif 
    if X_file.endswith(('.tif', '.tiff')):
        X_image_load = np.array(Image.open(X_file)) #assumption: X_file is a full path that PIL can open
    if X_image_load.shape == (728, 968):
        X_image_for_stack = Image.fromarray(X_image_load).resize((1936, 1460))
    elif X_image_load.shape == (1460, 1936):
        X_image_for_stack = X_image_load
    else:
        print(X_file + ' has wrong size')
    X_data[n,:,:] = np.array(X_image_for_stack) #stack raw images
    y_image_load = np.array(Image.open(y_file)) #assumption: y_file is a full path that PIL can open
    if y_image_load.shape == (1460, 1936):
        y_image_for_stack = np.array(y_image_load)
    elif y_image_load.shape == (728, 968):
        y_image_for_stack = Image.fromarray(y_image_load).resize((1936, 1460))
    elif y_image_load.shape == (1456, 1936):
        y_image_for_stack = Image.fromarray(y_image_load).resize((1936, 1460))
    Y_data[n,:,:] = np.array(y_image_for_stack) #stack raw label images (binary masks)
    n += 1

X = []
Y = []
Coverage = 10
image_width = 1936
image_height = 1460
sample_px = 160
n_samples = int(Coverage * image_width * image_height  / (sample_px**2))
#randomly select points to subsample from
H = np.random.randint(int(sample_px/2) + 1,image_height - int(sample_px/2), n_samples) #height sampling
W = np.random.randint(int(sample_px/2) + 1,image_width -int(sample_px/2), n_samples) #width sampling
Stack_idx = np.random.randint(0, train_set_size, n_samples) #image stack sampling
Transp = np.random.randint(0, 2, n_samples) #randomly transpose images
Blur = np.random.randint(0, 2, n_samples) #randomly blur images
#loop over all n_samples for random subsampling and append 160 x 160 array to stacked set of data and labels
half = sample_px // 2 #integer half-width, so every slice index stays an int
for i in range(n_samples):
    tmp_X = X_data[Stack_idx[i],H[i]-half:H[i]+half,W[i]-half:W[i]+half]
    tmp_X = (tmp_X-np.min(tmp_X))/np.max(tmp_X) #normalize data between 0 and 1
    if Transp[i]:
        tmp_X = tmp_X.T #transpose
    if Blur[i] and i > 5000 and (i%20 == 1):
        tmp_X = gaussian_filter(tmp_X,np.random.randint(2, 5)) #blur for later in training epoch
    tmp_Y = Y_data[Stack_idx[i],H[i]-half:H[i]+half,W[i]-half:W[i]+half]
    tmp_Y = tmp_Y/255 #normalize data between 0 and 1
    if Transp[i]:
        tmp_Y = tmp_Y.T #transpose
    X.append(tmp_X) #collect the patch and its mask
    Y.append(tmp_Y)

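x_train_tensor = torch.from_numpy(np.stack(X)).float() #stack first; building a tensor from a list of arrays is slow
y_train_tensor = torch.from_numpy(np.stack(Y)).float()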
dataset = TensorDataset(x_train_tensor, y_train_tensor)
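
One alternative to precomputing every crop is to draw the crops on the fly inside a `Dataset`, so only the full images stay in memory. The sketch below is a minimal version of that idea, not an established recipe: `RandomPatchDataset` and `patches_per_epoch` are hypothetical names, the arrays are random stand-ins for the real images and masks, and the min-max scaling is one assumed choice of normalization.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class RandomPatchDataset(Dataset):
    """Hypothetical helper: returns one random (image, mask) crop per item."""

    def __init__(self, images, masks, patch_size=160, patches_per_epoch=1000):
        # images, masks: float arrays of shape (N, H, W)
        assert images.shape == masks.shape
        self.images = images
        self.masks = masks
        self.patch_size = patch_size
        self.patches_per_epoch = patches_per_epoch

    def __len__(self):
        return self.patches_per_epoch

    def __getitem__(self, idx):
        n, h, w = self.images.shape
        p = self.patch_size
        # Pick an image and a top-left corner so the crop stays inside the image
        i = torch.randint(0, n, (1,)).item()
        top = torch.randint(0, h - p + 1, (1,)).item()
        left = torch.randint(0, w - p + 1, (1,)).item()
        x = self.images[i, top:top + p, left:left + p]
        y = self.masks[i, top:top + p, left:left + p]
        # Min-max scale the image crop to [0, 1]; the epsilon guards flat crops
        x = (x - x.min()) / (x.max() - x.min() + 1e-8)
        # Add a channel dimension: (1, p, p)
        x = torch.from_numpy(np.ascontiguousarray(x)).float().unsqueeze(0)
        y = torch.from_numpy(np.ascontiguousarray(y)).float().unsqueeze(0)
        return x, y


# Random stand-in data with the image size from this thread
X_data = np.random.rand(2, 1460, 1936).astype(np.float32)
Y_data = (np.random.rand(2, 1460, 1936) > 0.5).astype(np.float32)

dataset = RandomPatchDataset(X_data, Y_data, patch_size=160, patches_per_epoch=64)
loader = DataLoader(dataset, batch_size=16)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)
```

Augmentations such as the transpose and blur above could go inside `__getitem__` as well, applied to the crop before it is converted to a tensor.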

You could probably use unfold directly to create these patches.
Have a look at this post for a simple example.
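
For reference, here is a small sketch of what `unfold` does on a tiny tensor. The 4 x 4 input and the patch size of 2 are made up for illustration; they are not taken from the linked post.

```python
import torch

x = torch.arange(16.0).view(1, 4, 4)          # (channels, H, W)
k, d = 2, 2                                   # patch size and stride

# Unfold the height, then the width: (c, n_rows, n_cols, k, k)
patches = x.unfold(1, k, d).unfold(2, k, d)
print(patches.shape)

# Flatten the patch grid into a batch dimension: (n_rows * n_cols, 1, k, k)
flat = patches.contiguous().view(-1, 1, k, k)
print(flat.shape)

# Patches are ordered row by row, so index 1 is grid row 0, column 1
print(flat[1, 0])
```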

Thanks, that post was very helpful. I’m still a little confused about how I should be feeding the patches into my model for prediction and putting the predictions back together.

x = torch.randn(1, 500, 500)  # c, h, w
k = 160 # kernel size
d = 160 # stride
patches = x.unfold(1, k, d).unfold(2, k, d)
unfold_shape = patches.size()
patches = patches.contiguous().view(-1, 1, k, k) #my model takes tensors of shape (1,1,160,160)
data = DataLoader(patches,batch_size=2,num_workers=0)
for batch in data:
   output = model(batch)
#collect outputs

You won’t be able to create the same input shape using the specified shapes, as the input shape of 500 divided by the kernel size of 160 will have a remainder.
If that’s important, have a look at this code, which adds padding to the input.
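
To make the remainder concrete: with a size of 500 and a kernel and stride of 160, `unfold` yields 3 patches per side and silently drops the last 20 pixels. The sketch below shows one way to pad up to the next multiple of the kernel size first. `pad_to_multiple` is a hypothetical helper written for this example, not the code from the linked post, and zero padding is an assumption.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, k):
    """Hypothetical helper: zero-pad the last two dims of x up to multiples of k."""
    h, w = x.shape[-2:]
    pad_h = (-h) % k          # 0 when h is already a multiple of k
    pad_w = (-w) % k
    top, left = pad_h // 2, pad_w // 2
    # F.pad expects (left, right, top, bottom) for the last two dims;
    # the right/bottom side takes the extra pixel when the total is odd
    padded = F.pad(x, (left, pad_w - left, top, pad_h - top))
    return padded, (top, left)


x = torch.randn(1, 500, 500)          # c, h, w
k = d = 160
n_without_pad = (500 - k) // d + 1    # patches per side that unfold would return

padded, (top, left) = pad_to_multiple(x, k)
patches = padded.unfold(1, k, d).unfold(2, k, d)
print(n_without_pad, padded.shape, patches.shape)
```

Keeping `top` and `left` around makes it easy to slice the padding away again after the predictions are stitched back together.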

Thanks for the clarification. The input shape above was arbitrary, but padding will still be useful for my actual images. My question was a bit more basic, relating to the use of the trained model on images for segmentation. If I’m starting with an image, unfolding it, getting model predictions, and then piecing the original shape back together, what is the best way to keep track of each patch, if we assume they’re not overlapping (for now)?

My original code started with an empty numpy array with the shape of the original image and filled it with the patches, but it seems like this isn’t the most efficient way to process the images. I ultimately want to use a DataLoader to feed all of the patches into the model, but I can’t quite understand how to keep track of their location in the original image.

My old code for getting patches:

if image.shape == (1460, 1936): #actual image shape
        x_points = np.uint16(np.linspace(80, 1460-80, 10)) #split x into 10 slices
        y_points = np.uint16(np.linspace(80, 1936-80, 12)) #split y into 12 slices
        output_mask = np.zeros([1460, 1936]) #create placeholder
        for x_coords in tqdm(x_points): #loop over x and y coordinates and fill in the placeholder
            for y_coords in y_points:
                tmp_data = image[x_coords-80:x_coords+80,y_coords-80:y_coords+80]
                tmp_data = (tmp_data-np.min(tmp_data))/np.max(tmp_data)
                X = torch.from_numpy(tmp_data)
                X = X.view(1,1,160,160).to(device, dtype = torch.float)
                prediction = model(X)
                tmp_out = prediction.cpu().detach()[0][1]>threshold
                output_mask[x_coords-80:x_coords+80,y_coords-80:y_coords+80] = tmp_out

Does it make sense to convert from a tensor back to a numpy array for each patch?
The code below gives me a padded tensor of size (108, 1, 160, 160)

x = torch.randn(1, 1460, 1936)  # batch, h, w
k = 160 # kernel size
d = 160# stride
x = F.pad(x,(x.size(2)%k // 2,x.size(2)%k // 2,
             x.size(1)%k // 2,x.size(1)%k // 2))
patches = x.unfold(1, k, d).unfold(2, k, d) 
unfold_shape = patches.size()
patches = patches.contiguous().view(-1, 1,k, k) #(108, 1, 160, 160)

Since I can’t run all patches through at once, I need to batch them, collect the predictions, and then fold them back into the full image. Conceptually, I don’t understand how to do that with PyTorch. Thanks again for your help.

I’ve figured out how to reconstruct the image from the patches after model prediction with the code below. I’m not sure the padding calculations in the linked post were correct: they gave me half of the remainder as padding, rather than the amount needed to pad up to a multiple of 160.

x0 = torch.randn(1460, 1936)  # h, w
# kernel size
k = 160 
# stride
d = 160
hpad = (k-x0.size(0)%k) // 2 
wpad = (k-x0.size(1)%k) // 2 
#pad x0
x = F.pad(x0,(wpad,wpad,hpad,hpad)) 
patches = x.unfold(0, k, d).unfold(1, k, d) 
unfold_shape = patches.size()
#reshape to (batch,1,h,w)
patches = patches.contiguous().view(-1, 1,k, k)
#create storage tensor
temp = torch.empty(patches.shape) 

#loop over all patches, feed model predictions back into storage tensor
for i,patch in tqdm(enumerate(patches)):
    temp[i,:,:,:] = model(patch.view(-1,1,k,k).to(device, dtype = torch.float)).cpu().detach()[0][1]

# Reshape back
patches_orig = temp.view(unfold_shape)
output_h = unfold_shape[0] * unfold_shape[2]
output_w = unfold_shape[1] * unfold_shape[3]
patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
patches_orig = patches_orig.view(output_h, output_w)
#slice away padding
reconstructed_image = patches_orig[hpad:patches_orig.shape[0]-hpad,wpad:patches_orig.shape[1]-wpad]

The last thing I’m stuck on is getting the index from the batch loader to reconstruct the patched image. This post has a lot of discussion on the topic, but not any simple answers.

Wouldn’t the provided code snippet automatically reconstruct the image using the correct patches?
Also, if you put all patches into the batch dimension, you could avoid the for loop, which should be slower than a single forward pass.
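
If all patches do not fit on the GPU in a single forward pass, a middle ground is to process them in chunks and concatenate the results. This is a minimal sketch under assumptions: the `nn.Conv2d` layer is only a stand-in for the trained U-Net (assumed to output 2 channels, with channel 1 being the class of interest), and the chunk size of 16 is arbitrary.

```python
import torch
import torch.nn as nn

k = 160
patches = torch.randn(130, 1, k, k)                 # stand-in for the unfolded patches
model = nn.Conv2d(1, 2, kernel_size=3, padding=1)   # stand-in for the U-Net
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

outputs = []
with torch.no_grad():
    for chunk in patches.split(16):                 # chunks keep the original patch order
        pred = model(chunk.to(device))              # (b, 2, k, k)
        outputs.append(pred[:, 1:2].cpu())          # keep channel 1 only: (b, 1, k, k)

# Same shape and order as `patches`, so it can be viewed back with unfold_shape
temp = torch.cat(outputs)
print(temp.shape)
```

Because `split` walks through the tensor in order, no index bookkeeping is needed before reshaping `temp` back into the image.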

Yes, the snippet does reconstruct the patched image, albeit one patch at a time. So you’re suggesting something like the following for batch processing:

reshaped = patches.view(-1,k,k) #(batch,h,w)
temp = torch.empty(reshaped.shape)
loader = DataLoader(dataset=reshaped,batch_size=16)
for batch in loader:
   x = batch.view(-1,1,k,k).to(device, dtype = torch.float)
   temp[batch_idx,:,:] = model(x)

I don’t understand how to get the patch index from each batch to fill the empty tensor (temp). Can you clarify how that should work? Thanks!
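
One way to get a patch index out of the loader is to store the indices alongside the patches, for example with a `TensorDataset`, so every batch carries its own positions. This is a sketch under assumptions: the `nn.Conv2d` layer stands in for the trained model (assumed to output 2 channels), and the data is random.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

k = 160
reshaped = torch.randn(130, k, k)                          # (n_patches, h, w)
model = nn.Conv2d(1, 2, kernel_size=3, padding=1).eval()   # stand-in for the U-Net

# Pair every patch with its position in the flattened patch grid
indices = torch.arange(reshaped.size(0))
loader = DataLoader(TensorDataset(reshaped, indices), batch_size=16, shuffle=True)

temp = torch.empty(reshaped.shape)
with torch.no_grad():
    for batch, batch_idx in loader:
        x = batch.view(-1, 1, k, k).to(dtype=torch.float)
        # Write each prediction (channel 1) back to the slot of its own patch
        temp[batch_idx] = model(x)[:, 1]

print(temp.shape)
```

With `shuffle=False` the loader returns patches in their original order, so a running offset (or simply `torch.cat` over the outputs) would work without the explicit indices.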

I have the same problem. Were you able to solve it?

I want to split an image into patches using the unfold function, with the following commands:

img = io.imread('02.png')
img = img[np.newaxis, :]
img = (np.pad(img, ((0, 0), (26, 26), (26, 26)), mode='constant')/255.).astype(np.float32)
img = torch.unsqueeze(torch.from_numpy(img), dim=0)

kc, kh, kw = 1, 64, 64  # kernel size
dc, dh, dw = 1, 64, 64  # stride

patches = img.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()
patches = patches.contiguous().view(patches.size(0), -1, kh, kw)

and I get the patched image with the commands below:

fig, ax = plt.subplots(figsize=(8, 8), nrows=8, ncols=8)
plt.subplots_adjust(hspace=0.02, wspace=0.005)
for i, axes in enumerate(ax.ravel()):
    axes.imshow(patches[0, i]*70, vmin=0, vmax=70, cmap='pyart_NWSRef')

Then I apply a convolution using nn.Conv2d:

patches_conv = nn.Conv2d(64, 64, 3, 1, 1)(patches)

fig, ax = plt.subplots(figsize=(8, 8), nrows=8, ncols=8)
plt.subplots_adjust(hspace=0.02, wspace=0.005)
for i, axes in enumerate(ax.ravel()):
    axes.imshow(patches_conv[0, i].detach().numpy()*70, vmin=0, vmax=70, cmap='pyart_NWSRef')

The output does not reconstruct the patched model prediction image. Using the F.fold function to reconstruct the image does not work either; I get the same image.

So I’m very confused. Could you help me?

Sorry to disturb you. I want to unfold the input image to feed it into the model and fold the patched model output back. Could you help me? @ptrblck The original image is below.


Thanks so much!
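
One possible cause, judging only from the snippet above: the patches end up in the channel dimension, with shape (1, 64, 64, 64), so `nn.Conv2d(64, 64, ...)` mixes all 64 patches together and its output channels no longer correspond to individual patches. A minimal sketch of the alternative is below, with the patches moved to the batch dimension and `F.fold` used to stitch them back. The random image and the single-channel convolution are stand-ins, not the original data or model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

H = W = 512
img = torch.rand(1, 1, H, W)     # stand-in for the padded image, (N, C, H, W)
k = 64                           # patch size equals stride, so patches do not overlap

# F.unfold returns (N, C*k*k, L), where L is the number of patches
cols = F.unfold(img, kernel_size=k, stride=k)
n_patches = cols.size(-1)

# Move the patches into the batch dimension: (L, 1, k, k)
patches = cols.transpose(1, 2).reshape(n_patches, 1, k, k)

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # each patch is processed on its own
with torch.no_grad():
    out = conv(patches)                            # (L, 1, k, k)

# Back to the (N, C*k*k, L) layout that F.fold expects, then stitch
cols_out = out.reshape(1, n_patches, k * k).transpose(1, 2).contiguous()
recon = F.fold(cols_out, output_size=(H, W), kernel_size=k, stride=k)
print(recon.shape)
```

Since every patch is convolved separately with its own zero padding, the stitched result can show seams at the patch borders; overlapping patches are one common way to reduce that.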