I’m new to PyTorch.
I have roughly 1000 images of size 250×1000, where about 60% of the pixel values are NaN.
Currently I’m processing these on a CPU with MATLAB, and it’s slower than I’d like.
I’m trying to move the processing to the GPU, and a friend suggested using PyTorch tensors.
One of the slow steps is applying a median filter around each pixel of each slice, if that pixel isn’t NaN.
Here’s my code:
It turns out the same code runs in only 2 seconds on the CPU, while the GPU takes roughly 30 seconds. And I’m repeating this 1000 times.
I think I’m doing something naive that’s making the GPU slow.
Based on the posted screenshot, I would assume you could use `unfold` to create the image patches and call `nanmedian()` on them. While this would increase the memory usage, it should also speed up the workload (assuming it fits into your GPU memory).
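Something along these lines, as a minimal sketch (sizes are scaled down from the 250×1000 image and 21×21 kernel in the thread; padding with NaN keeps the output shape):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# toy stand-in for one slice: a small image with ~60% NaNs
x = torch.randn(64, 64, device=device)
x[torch.rand_like(x) < 0.6] = float("nan")

k = 7        # kernel size (21 in the thread)
pad = k // 2
# pad with NaN so border pixels still see k*k values
x_pad = torch.nn.functional.pad(x, (pad, pad, pad, pad), value=float("nan"))

patches = x_pad.unfold(0, k, 1).unfold(1, k, 1)  # (64, 64, k, k)
patches = patches.flatten(start_dim=2)           # (64, 64, k*k)
filtered = patches.nanmedian(dim=2).values       # NaNs are ignored per patch
```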
PS: you can post code snippets by wrapping them in three backticks ```, which makes debugging easier.
Thank you for the reply. I’ve tried unfolding the images as shown below. The unfolded tensor has size 250×1024×441, and then I called `nanmedian()` on each patch, which ended up being even slower, actually…
I think the double for loop is what’s killing the time.
Do you know any way other than `nanmedian()` to achieve the same thing?
Is there any way to interpolate the values along the third dimension to fill in all 441 values and eliminate the NaNs? Then I could simply take the middle index… (I found out that sorting along the third dimension takes no time on the GPU, which is fantastic.)
Or is there a way to set half of the NaNs to positive infinity and the other half to negative infinity?
I’m still exploring; in the meantime, if you know a quick answer, that would be awesome.
```python
import time
import torch

kh, kw = 21, 21  # kernel size
dh, dw = 1, 1    # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
patches = torch.flatten(patches, start_dim=2)  # (250, 1024, kh * kw)
out = torch.empty(250, 1024, device=x.device)
start = time.time()
for i in range(250):
    for j in range(1024):
        out[i, j] = patches[i, j].nanmedian()
end = time.time()
print("Time elapsed:", end - start)
```
P.S. the ``` method looks awesome
Very close now. Given that `patches` is sorted along the third dimension, the code below gives the index of the median. Now I need to find out how to extract the median from `patches` given these indices…
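For the extraction step, `torch.gather` along the sorted dimension should do it. A sketch under the setup described above (NaNs replaced by `+inf` before sorting, so they land at the end; the toy shapes and names are mine):

```python
import torch

H, W, P = 4, 5, 9
patches = torch.randn(H, W, P)
patches[torch.rand(H, W, P) < 0.5] = float("inf")  # stand-in for former NaNs

sorted_p, _ = patches.sort(dim=2)           # infs sort to the end
valid = torch.isfinite(patches).sum(dim=2)  # non-NaN count per patch
idx = (valid - 1).clamp(min=0) // 2         # index of the (lower) median
median = sorted_p.gather(2, idx.unsqueeze(2)).squeeze(2)
```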
My idea was to eliminate the nested for loop, permute the actual patches into a single dimension, and then call `nanmedian` with the `dim` argument on it.
Ohhh, I’ve been using it like `tensor.nanmedian()` and forgot that it can take `dim` as input. Thanks, this helps a lot!
Ideally, I’d put all my slices in one tensor, but that would take 400 GB of GPU memory. So I’m running things in batches with a for loop.
The first iteration is pretty quick, but the second and onwards are a couple of orders of magnitude slower than the first…
0.2 seconds is nice, but 9 seconds for 20 of them is roughly how long it would take on the CPU…
I suppose I need to somehow free some GPU memory between batches.
I’ve tried deleting the unfolded tensors with `del Tensor` followed by `torch.cuda.empty_cache()`, but this doesn’t help.
`empty_cache()` will slow down your code further, as new memory would then need to be allocated again, which synchronizes the code.
How are you measuring the performance? Often users don’t synchronize the code, and since CUDA operations are executed asynchronously, you would need to synchronize via `torch.cuda.synchronize()` before starting and stopping the timers.
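A minimal timing helper along those lines (a sketch; the function name is mine):

```python
import time
import torch

def cuda_timed(fn):
    """Run fn() and return (result, elapsed seconds), synchronizing
    before starting and before stopping the timer."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    out = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.time() - start

out, elapsed = cuda_timed(lambda: torch.randn(256, 256).sum())
```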
Yeah, I wasn’t synchronizing the code. The first iteration’s time is actually not synchronized…
Here’s my code. Since most of the values in these images are NaN, I don’t need to calculate the median for those pixels. Right now I’m calculating the median for everything and then masking out the ones I don’t need, which is wasteful, but it’s the only approach I found that produced correct results.
How can I calculate the median only where the pixels are not NaN, while maintaining the original shape and keeping NaN where it was?
```python
# test gpu version
import time
import torch

dev = "cuda:0"
# dev = "cpu"  # uncomment to compare against the CPU
device = torch.device(dev)

for i in range(40):
    torch.cuda.synchronize()
    start = time.time()
    # unfold image patches into the third dimension
    kh, kw = size_filter * 2 + 1, size_filter * 2 + 1  # kernel size
    dh, dw = 1, 1  # stride
    patches = S_pad.unfold(0, kh, dh).unfold(1, kw, dw)
    patches = torch.flatten(patches, start_dim=2)
    # get nan mask (from the center pixel of each patch)
    nan_mask = torch.isnan(patches[..., patches.shape[-1] // 2])
    # calculate median along third dimension
    med = patches.nanmedian(dim=2).values
    # set previously nan values back to nan
    med[nan_mask] = float("nan")
    torch.cuda.synchronize()
    end = time.time()
print("Median filter for 20 projections time elapsed:", end - start)
```
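For computing the median only where the center pixel is not NaN, one option is boolean indexing on the flattened patches: select only the rows whose center is valid, run `nanmedian` on just those, and write them back into a NaN-filled output. A sketch with toy shapes (names are mine; `P` is the flattened patch length with the original pixel at its center):

```python
import torch

H, W, P = 6, 8, 9  # toy sizes; P = kernel_h * kernel_w
patches = torch.randn(H, W, P)
patches[torch.rand(H, W, P) < 0.6] = float("nan")

center = P // 2  # for odd kernels, the original pixel sits mid-patch
valid = ~torch.isnan(patches[..., center])

out = torch.full((H, W), float("nan"))
# nanmedian only over the rows whose center pixel is not NaN
out[valid] = patches[valid].nanmedian(dim=1).values
```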