# Apply median filter on image stack with GPU

I’m new to pytorch.

I have roughly 1000 images of size 250*1000, where roughly 60% of the pixel values are nan.

Currently I’m processing these on a CPU with matlab and it’s slower than I wanted.

I’m trying to put the processing on GPU, and using PyTorch tensor was suggested by a friend.

One of the steps that takes long is applying a median filter to each non-NaN pixel of each slice.

Here’s my code:

It turns out the same code runs in only 2 seconds on the CPU, while the GPU takes roughly 30 seconds. And I’m repeating this 1000 times.

I suspect I’m doing something naive that’s making the GPU slow.

Based on the posted screenshot, I would assume you could use `unfold` to create the image patches and call `nanmedian()` on them. While this would increase the memory usage, it should also speed up the workload (assuming it fits into your GPU memory).
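A minimal sketch of that idea on a toy input (names like `img` are placeholders, not the original variables; a 3×3 window is assumed):

```python
import torch

# Toy stand-in for one padded slice, with some NaN holes.
img = torch.arange(25, dtype=torch.float32).reshape(5, 5)
img[img % 4 == 0] = float("nan")

kh, kw = 3, 3                                      # kernel size
patches = img.unfold(0, kh, 1).unfold(1, kw, 1)    # (3, 3, kh, kw)
patches = patches.flatten(start_dim=2)             # (3, 3, kh*kw)

# One nanmedian call along the patch dimension filters the whole slice.
med = patches.nanmedian(dim=2).values
print(med.shape)  # torch.Size([3, 3])
```

Note that `nanmedian` follows `median`'s convention and returns the lower of the two middle values when a patch has an even number of valid entries.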

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier.

Thank you for the reply. I’ve tried unfolding the images as shown below. The unfolded tensor has size `250*1024*441`, and then I called `nanmedian()` on each patch, which ended up being even slower, actually…

I think the double for loop is what’s killing the time.

Do you know any way other than `nanmedian()` to achieve the same thing?

Is there any way to interpolate the values along the third dimension to fill all 441 entries and eliminate the NaN values? Then I could simply take the middle index… (I found that sorting along the third dimension takes no time on the GPU, which is fantastic.)

Or is there a way to set half of the NaNs to positive infinity and the other half to negative infinity?

I’m still exploring, in the meantime if you know a quick answer that would be awesome.

```python
import time
import torch

x = S_pad
kh, kw = 21, 21  # kernel size
dh, dw = 1, 1    # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
patches = torch.flatten(patches, start_dim=2)
unfold_shape = patches.size()
print(unfold_shape)
patches, indices = torch.sort(patches, dim=2)

start = time.time()
for i in range(250):
    for j in range(1024):
        S_med[i + 10][j + 10] = patches[i, j, :].nanmedian()
end = time.time()
print("Time elapsed:", end - start)
```

P.S. the ``` method looks awesome

Very close now. Given `patches` is sorted along the third dimension, the code below gives the index of the median. Now I need to find out how to extract the median from `patches` given these indices…

```python
mask = ~patches.isnan()
```
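For reference, one vectorized way to do that extraction (a sketch, relying on `torch.sort` placing NaNs at the end of each patch) is to count the valid entries and `gather` at the middle index:

```python
import torch

# Sketch: pick the median out of NaN-sorted patches without a Python loop.
# `patches` is assumed to be (H, W, P) and sorted along dim=2,
# which places the NaNs at the end of each patch.
patches = torch.tensor([[[1., 3., float("nan")],
                         [2., 4., 6.]]])
patches, _ = torch.sort(patches, dim=2)

mask = ~patches.isnan()
n_valid = mask.sum(dim=2)                   # valid entries per patch
idx = ((n_valid - 1) // 2).clamp(min=0)     # middle index of the valid prefix
med = patches.gather(2, idx.unsqueeze(2)).squeeze(2)
med[n_valid == 0] = float("nan")            # all-NaN patches stay NaN
print(med)
```

For an even number of valid entries this picks the lower of the two middle values, matching `torch.median`'s convention.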

My idea was to eliminate the nested for loop, permute the actual patches into a single dimension, and then call `nanmedian` with the `dim` argument on it.
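That idea can be sketched like this (shapes are assumptions standing in for the real `250*1024*441` tensor):

```python
import torch

# Sketch: replace the nested loops with a single reduction along the
# patch dimension. `patches` stands in for the (H, W, kh*kw) tensor.
patches = torch.randn(4, 5, 9)
patches[0, 0, :3] = float("nan")        # some missing values

med = patches.nanmedian(dim=2).values   # (H, W) in one call
print(med.shape)  # torch.Size([4, 5])
```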

Ohhh, I’ve been using it like `tensor.nanmedian()` and forgot that it can take `dim` as input. Thanks, this helps a lot.

Ideally, I want to nicely put all my slices in one tensor, but that would take 400GB of GPU memory. So I’m running things in batches with a for loop.

The first iteration is pretty quick, but the second and onwards are a couple of orders of magnitude slower than the first…

i.e.

0.2 seconds is nice, but 9 seconds for 20 of them is roughly how long it would take on the CPU…

I suppose I need to somehow clean up some GPU memory between batches.
I’ve tried deleting the unfolded tensors with `del Tensor` followed by `torch.cuda.empty_cache()`, but this doesn’t help.

Any suggestions?

Thanks.

`empty_cache()` will slow down your code further as new memory would need to be allocated via synchronizing `cudaMalloc` calls.
How are you measuring the performance? Often users are not synchronizing the code and since CUDA operations are executed asynchronously, you would need to synchronize the code via `torch.cuda.synchronize()` before starting and stopping the timers.
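The usual pattern looks like this (`work()` is a placeholder for the actual GPU workload):

```python
import time
import torch

def work(x):
    # placeholder for the actual GPU workload
    return torch.sort(x, dim=-1)[0]

x = torch.randn(256, 1024)
if torch.cuda.is_available():
    x = x.cuda()
    torch.cuda.synchronize()   # finish pending work before starting the timer
start = time.time()
y = work(x)
if torch.cuda.is_available():
    torch.cuda.synchronize()   # wait for the kernels to finish before stopping
elapsed = time.time() - start
print(f"Time elapsed: {elapsed:.4f} s")
```

Without the second `synchronize()`, the timer often stops while the kernels are still running, which is why the first iteration can look deceptively fast.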

Yeah, I wasn’t synchronizing the code. The first iteration’s time is actually not synchronized…

Here’s my code. Since most of the values in these images are NaN, I don’t need to calculate the median for those. Right now I’m calculating the median for everything and then masking out the ones I don’t need, which is wasteful, but it’s what I found that produces a correct result.

How can I calculate the median only where the pixels are not NaN, while maintaining the original shape and keeping the NaNs where they are?
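One way to keep the holes (a sketch on a toy input, not the original code) is to filter everywhere and then restore NaN wherever the original pixel was NaN, using `torch.where`:

```python
import torch

# Toy slice with NaN holes; pad with NaN so nanmedian ignores the border.
img = torch.tensor([[1., float("nan"), 3.],
                    [4., 5., float("nan")],
                    [7., 8., 9.]])

pad = 1
padded = torch.nn.functional.pad(img, (pad, pad, pad, pad), value=float("nan"))
patches = padded.unfold(0, 3, 1).unfold(1, 3, 1).flatten(start_dim=2)
med = patches.nanmedian(dim=2).values     # same (3, 3) shape as img

out = torch.where(img.isnan(), img, med)  # NaN in -> NaN out
print(out)
```

This still computes the median for every pixel, but the masking itself is a single elementwise op, so the wasted work is only in the reduction.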

```python
# test gpu version
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)

for i in range(0, 40):
    torch.cuda.synchronize()
    start = time.time()

    S_ini = S_correct_tensor[:, :, i * 20:(i + 1) * 20 - 1]
    size_filter = 10
    # unfold image patches to third dimension
    kh, kw = size_filter * 2 + 1, size_filter * 2 + 1  # kernel size
    dh, dw = 1, 1  # stride
    patches = S_pad.unfold(0, kh, dh).unfold(1, kw, dw)
    patches = torch.flatten(patches, start_dim=3)
```