Apply a median filter to an image stack on the GPU

I’m new to PyTorch.

I have roughly 1000 images of size 250*1000, where roughly 60% of the pixel values are NaN.

Currently I’m processing these on a CPU with MATLAB, and it’s slower than I’d like.

I’m trying to move the processing to the GPU, and a friend suggested using PyTorch tensors.

One of the slow steps is applying a median filter to each pixel of each slice, if that pixel is not NaN.

Here’s my code:

It turns out the same code runs in only 2 seconds on the CPU, while the GPU takes roughly 30 seconds. And I need to repeat this 1000 times.

I think I’m doing something naive that’s making the GPU slow.

Based on the posted screenshot, I would assume you could use unfold to create the image patches and call nanmedian() on them. While this would increase the memory usage, it should also speed up the workload (assuming the unfolded tensor fits into your GPU memory).
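A minimal sketch of that suggestion for a single 2D slice. Assumptions not taken from the thread: the border is padded with NaN so nanmedian() ignores it and the output keeps the input shape, and `nan_median_filter2d` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def nan_median_filter2d(img: torch.Tensor, size_filter: int) -> torch.Tensor:
    """NaN-aware median filter over a (2*size_filter+1)^2 window (hypothetical helper)."""
    k = 2 * size_filter + 1
    s = size_filter
    # Pad height and width with NaN (assumption) so the output has the input's shape.
    padded = F.pad(img, (s, s, s, s), value=float("nan"))
    # unfold gives an (H, W, k, k) view of all windows; flatten copies it to (H, W, k*k).
    patches = padded.unfold(0, k, 1).unfold(1, k, 1).flatten(start_dim=2)
    # One vectorized reduction over the window dimension instead of a Python loop.
    return patches.nanmedian(dim=-1).values
```

Replacing the per-pixel loop with a single nanmedian(dim=-1) call lets the GPU process all windows at once; the cost is that the flattened patches hold k*k values per pixel.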

PS: you can post code snippets by wrapping them in three backticks ```, which makes debugging easier :wink:

Thank you for the reply. I’ve tried unfolding the images as shown below. The unfolded tensor has size 250*1024*441, and I then called nanmedian() on each patch, which actually ended up being even slower…

I think the double for loop is what’s killing the time.

Do you know any way other than nanmedian() to achieve the same thing?

Is there any way I can interpolate the values along the third dimension to fill all 441 entries and eliminate the NaN values? Then I could simply take the middle index… (I found out that sorting along the third dimension takes almost no time on the GPU, which is fantastic.)

Or is there a way to set half of the NaNs to positive infinity and the other half to negative infinity?
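The half-and-half idea does work when the window length is odd (21*21 = 441). If a patch has k NaNs and (k+1)//2 of them become -inf while the rest become +inf, then after sorting, the middle element is exactly the lower median of the non-NaN values, which is what nanmedian() returns. A minimal sketch; `median_via_inf_split` is a hypothetical name, and the input is assumed to contain no genuine ±inf values:

```python
import torch

def median_via_inf_split(patches: torch.Tensor) -> torch.Tensor:
    """Lower median of non-NaN values along the last dim via the +-inf trick.

    Assumes the last dimension has odd length and holds no real +-inf values.
    """
    L = patches.shape[-1]
    assert L % 2 == 1, "window length must be odd"
    nan_mask = torch.isnan(patches)
    k = nan_mask.sum(dim=-1, keepdim=True)                  # NaN count per patch
    rank = torch.cumsum(nan_mask.to(torch.int64), dim=-1)   # 1..k at NaN positions
    n_neg = (k + 1) // 2                                    # NaNs that become -inf
    neg_inf = torch.tensor(float("-inf"), dtype=patches.dtype, device=patches.device)
    pos_inf = torch.tensor(float("inf"), dtype=patches.dtype, device=patches.device)
    fill = torch.where(rank <= n_neg, neg_inf, pos_inf)
    filled = torch.where(nan_mask, fill, patches)
    med = torch.sort(filled, dim=-1).values[..., L // 2]    # middle element
    # All-NaN patches would yield +-inf; report NaN for them instead.
    return torch.where(k.squeeze(-1) == L, torch.full_like(med, float("nan")), med)
```

This relies only on torch.sort, which you found to be fast on the GPU.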

I’m still exploring; in the meantime, if you know a quick answer, that would be awesome.

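```python
import time
import torch

# x: padded 2D image tensor on the GPU (defined elsewhere)
start = time.time()
kh, kw = 21, 21  # kernel size
dh, dw = 1, 1    # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)  # (H, W, kh, kw)
patches = torch.flatten(patches, start_dim=2)    # (H, W, kh*kw), here 250*1024*441
unfold_shape = patches.size()
patches, indices = torch.sort(patches, dim=2)
med = torch.empty(patches.shape[:2], device=patches.device)
for i in range(0, 250):
    for j in range(0, 1024):
        med[i, j] = patches[i, j].nanmedian()  # one nanmedian() call per patch
end = time.time()
print("Time elapsed:", end - start)
```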

P.S. The ``` method looks awesome.

Very close now. Given that patches is sorted along the third dimension, the code below gives the index of the median. Now I need to figure out how to extract the median from patches given these indices…
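One way to do the extraction step is torch.gather with one index per patch. A minimal sketch, assuming NaNs should simply be ignored: NaNs are turned into +inf before sorting so they are guaranteed to land at the end, the number of valid values gives the index of the lower median, and gather pulls that element out. `median_from_sorted` is a hypothetical name.

```python
import torch

def median_from_sorted(patches: torch.Tensor) -> torch.Tensor:
    """Lower median of the non-NaN values along the last dim, via sort + gather."""
    nan_mask = torch.isnan(patches)
    # Treat NaN as +inf so the NaN entries end up after all valid values.
    filled = patches.masked_fill(nan_mask, float("inf"))
    sorted_vals, _ = torch.sort(filled, dim=-1)
    n_valid = (~nan_mask).sum(dim=-1, keepdim=True)    # (..., 1), int64
    idx = ((n_valid - 1) // 2).clamp(min=0)            # index of the lower median
    med = torch.gather(sorted_vals, -1, idx).squeeze(-1)
    # Patches without any valid value have no median: return NaN there.
    return med.masked_fill(n_valid.squeeze(-1) == 0, float("nan"))
```

Picking the lower of the two middle values for an even count matches what nanmedian() returns, so the two approaches should agree.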


My idea was to eliminate the nested for loop by rearranging the patches so that each one lies along a single dimension, and then to call nanmedian with the dim argument on that tensor.

Ohhh, I’d been using it as tensor.nanmedian() and forgot that it can take dim as an argument. Thanks, this helps a lot.

Ideally, I’d put all my slices into one tensor, but that would take 400GB of GPU memory. So I’m processing them in batches with a for loop.
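The 400GB figure follows from the unfolding: with a 21×21 window, each 250×1000 slice becomes 250·1000·441 ≈ 1.1×10⁸ values, about 0.44 GB in float32, so 1000 slices need roughly 440 GB. A minimal sketch of the batched version, assuming an (N, H, W) float stack on the CPU and NaN padding at the border; `nan_median_filter_stack` and `batch_size` are hypothetical names. It writes into a preallocated CPU tensor and does not call torch.cuda.empty_cache(), since PyTorch’s caching allocator can reuse the freed blocks for the next batch.

```python
import torch
import torch.nn.functional as F

def nan_median_filter_stack(stack: torch.Tensor, size_filter: int,
                            batch_size: int, device: torch.device) -> torch.Tensor:
    """NaN-aware median filter for each slice of an (N, H, W) stack, in batches."""
    k = 2 * size_filter + 1
    s = size_filter
    out = torch.empty_like(stack)                         # preallocated result (CPU)
    for b0 in range(0, stack.shape[0], batch_size):
        batch = stack[b0:b0 + batch_size].to(device)      # (B, H, W) on the device
        padded = F.pad(batch, (s, s, s, s), value=float("nan"))  # pad H and W only
        patches = padded.unfold(1, k, 1).unfold(2, k, 1)  # (B, H, W, k, k) view
        med = patches.flatten(start_dim=3).nanmedian(dim=-1).values
        out[b0:b0 + batch_size] = med.cpu()               # copy results back
    return out
```

Choose batch_size so that batch_size·H·W·k² values fit into GPU memory.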

The first iteration is pretty quick, but the second and later ones are a couple of orders of magnitude slower than the first…


0.2 seconds is nice, but 9 seconds for 20 of them is roughly how long it would take on the CPU…

I suppose I need to somehow free up some GPU memory between batches.
I’ve tried deleting the unfolded tensors with del and then calling torch.cuda.empty_cache(), but this doesn’t help.

Any suggestions?


empty_cache() will slow down your code further, as new memory would then need to be allocated via synchronizing cudaMalloc calls.
How are you measuring the performance? Users often don’t synchronize the code. Since CUDA operations are executed asynchronously, you need to synchronize via torch.cuda.synchronize() before starting and before stopping the timers.
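A small timing helper that follows this advice. It is a sketch: `timed` is a hypothetical name, and on a CPU-only machine the synchronize calls are simply skipped.

```python
import time
import torch

def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed seconds).

    CUDA kernels are launched asynchronously, so we wait for pending work
    before starting the timer and again before stopping it; otherwise the
    measurement may only cover the kernel launches.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start
```

For example, `med, seconds = timed(lambda: patches.nanmedian(dim=-1).values)` reports the time the reduction actually took on the device.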

Yeah, I wasn’t synchronizing the code. The time I reported for the first iteration was measured without synchronization…

Here’s my code. Most of the values in these images are NaN, and I don’t need to calculate the median for those. Right now I’m calculating the median for every pixel and then masking out the ones I don’t need, which is dumb, but it’s what I found that produced the correct result.

How can I calculate the median only where the values are not NaN, while maintaining the original shape and keeping the NaNs where they are?
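One option is boolean indexing: keep the unfolded tensor as a view, select only the windows whose centre pixel is valid (this copies just those windows), reduce them, and scatter the results into a NaN-filled output of the original shape. A minimal sketch for one 2D slice, assuming NaN padding at the border; `nan_median_filter_valid_only` is a hypothetical name. With about 60% NaN pixels, it should copy and reduce only about 40% of the windows.

```python
import torch
import torch.nn.functional as F

def nan_median_filter_valid_only(img: torch.Tensor, size_filter: int) -> torch.Tensor:
    """Median filter computed only where the centre pixel is not NaN.

    Output has the same shape as img; NaN pixels stay NaN.
    """
    k = 2 * size_filter + 1
    s = size_filter
    valid = ~torch.isnan(img)                               # (H, W) pixels to filter
    padded = F.pad(img, (s, s, s, s), value=float("nan"))   # NaN border (assumption)
    patches = padded.unfold(0, k, 1).unfold(1, k, 1)        # (H, W, k, k) view, no copy
    selected = patches[valid].flatten(start_dim=1)          # (n_valid, k*k) copy
    out = torch.full_like(img, float("nan"))
    out[valid] = selected.nanmedian(dim=-1).values          # fill only valid pixels
    return out
```

For a stack of slices, the same masking works per slice, or on the unfolded batch with a mask over the leading dimensions.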

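```python
import time
import torch

# test GPU version
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)

for i in range(0, 40):
    if device.type == "cuda":
        torch.cuda.synchronize()  # finish pending GPU work before starting the timer
    start = time.time()
    # load proj: S_pad is the padded (H, W, 20) stack of 20 projections on `device`

    # unfold image patches into the last dimension
    kh, kw = size_filter*2+1, size_filter*2+1  # kernel size
    dh, dw = 1, 1  # stride
    patches = S_pad.unfold(0, kh, dh).unfold(1, kw, dw)  # (H', W', 20, kh, kw)
    patches = torch.flatten(patches, start_dim=3)        # (H', W', 20, kh*kw)

    # get NaN mask from the centre pixel of each patch
    nan_mask = torch.isnan(patches[..., patches.shape[-1] // 2])
    # calculate median along the last dimension
    S_med = patches.nanmedian(dim=-1).values
    del patches
    del S_pad
    # set previously NaN values back to NaN
    S_med[nan_mask] = float("nan")

    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for the median to finish before stopping the timer
    end = time.time()
    print("Median filter for 20 projections time elapsed:", end - start)
```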