Unfold uses too much GPU memory and has slow update

I am attempting to implement a cross-bilateral-filtering-like operation in the forward pass of a model I am building. Unfortunately, the unfold operations I use to create a sliding-window-like operation consume a lot of memory. This limits my batch size to 4, even though my volumes are only 64 x 64 x 8. Moreover, applying the optimizer is very slow: the update step takes 1.3 seconds per batch, while the forward and backward passes take around 0.05 s each. Is there any solution? I am aware of F.unfold(), but it currently only supports 2D images, not 3D volumes.

Here is the code for the forward pass:

```python
import torch
import torch.nn.functional as F


def forward(self, x, domain_neighbor):
    # Compute the guidance image
    x = F.pad(x, (1, 1, 1, 1, 1, 1), mode='constant')
    mat_size = x.shape
    guide_im = self.shared_denoiser(x)  # shared_denoiser is a weight-shared network

    # Compute 3x3x3 filter neighborhoods (the reshape assumes a single channel, C == 1)
    guide_im = guide_im.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1).reshape(-1, 1, 3, 3, 3)
    # Difference between each neighbor and the center voxel of its window
    range_neighbor = guide_im - guide_im[:, 0, 1, 1, 1].view(guide_im.shape[0], 1, 1, 1, 1)

    # Estimate filter coefficients
    domain_kernel = self.domain_coeffecients(domain_neighbor)
    # Note: range_neighbor is computed above but never used; a bilateral
    # range kernel is usually a function of these center differences.
    range_kernel = self.range_coeffecients(guide_im)

    # Apply the bilateral filter: normalized weighted average over each window
    x = x.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1).reshape(-1, 1, 3, 3, 3)
    weights = domain_kernel * range_kernel
    filtered_pixel = torch.sum(weights * x, dim=[1, 2, 3, 4]) / torch.sum(weights, dim=[1, 2, 3, 4])

    return filtered_pixel.view(-1, 1, mat_size[2] - 2, mat_size[3] - 2, mat_size[4] - 2)
```
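To see where the memory in the windowing step goes, here is a small check using a random tensor with the padded shape as a hypothetical stand-in for `x` (batch 4, one channel, a 64 x 64 x 8 volume padded by 1 on each side). The chained `Tensor.unfold` calls only create a strided view of the input; it is the `reshape(-1, 1, 3, 3, 3)` that materializes every overlapping window as its own copy.

```python
import torch

# Hypothetical padded batch: (B, C, D, H, W) = (4, 1, 8 + 2, 64 + 2, 64 + 2).
x = torch.randn(4, 1, 10, 66, 66)

# Chained unfold over the three spatial dims gives every 3x3x3 window.
win = x.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)
print(win.shape)  # torch.Size([4, 1, 8, 64, 64, 3, 3, 3])

# Still a view: it shares x's storage, so no new memory has been allocated yet.
print(win.data_ptr() == x.data_ptr())  # True

# Overlapping windows cannot be flattened as a view, so reshape() copies them.
patches = win.reshape(-1, 1, 3, 3, 3)
print(patches.shape)  # torch.Size([131072, 1, 3, 3, 3])
print(patches.data_ptr() == x.data_ptr())  # False
```

Every tensor derived from `patches` (the differences, the weights, their product) has the same 27-values-per-voxel size, and autograd may keep several of them alive until backward.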

Appreciate all help!!

Hi,

You seem to be using unfold a lot. It is usually called only once for a convolution operation.
I can see why calling it several times in a chain would be very memory hungry. I suspect guide_im and x are quite large, no?
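As a rough sanity check on "quite large", here is the float32 size of one materialized window tensor for the sizes in the question, assuming a single channel and 3x3x3 windows. Each such copy comes to about 13.5 MiB, so even several of them kept for backward stay in the tens of MiB. That suggests, though this is only an inference from the shapes, that much of the memory may go into `range_coeffecients`, which receives all 131,072 windows of a batch as tiny 3x3x3 volumes and stores its own activations for each of them.

```python
# Rough float32 footprint of one unfolded tensor for the sizes in the question.
# Assumptions: single channel, 3x3x3 windows, 4 bytes per element.
batch, depth, height, width = 4, 8, 64, 64
voxels = batch * depth * height * width  # one window per output voxel
window = 3 * 3 * 3                       # 27 neighbors per voxel
bytes_per_copy = voxels * window * 4     # float32

print(voxels, bytes_per_copy / 2**20)  # 131072 13.5
```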

> Moreover, applying the optimizer is very slow.

That is surprising. Do you have a small code sample (30-40 lines) that reproduces this?
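Until a repro is posted, one thing worth ruling out: CUDA kernels are launched asynchronously, so timing each phase with a plain wall clock can bill the backward pass's GPU work to whichever later call first has to wait for the GPU, such as the optimizer step. Below is a minimal timing sketch that synchronizes around each phase. The `Conv3d` model, the Adam optimizer and the input shape are hypothetical stand-ins, not the poster's model.

```python
import time

import torch
import torch.nn as nn


def timed(fn, device):
    # CUDA work is asynchronous: synchronize before and after so the elapsed
    # time belongs to this phase rather than to whatever call runs next.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    result = fn()
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return result, time.perf_counter() - start


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Conv3d(1, 8, kernel_size=3, padding=1).to(device)  # hypothetical stand-in
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(4, 1, 8, 64, 64, device=device)

opt.zero_grad()
out, t_fwd = timed(lambda: model(x), device)
loss = out.pow(2).mean()
_, t_bwd = timed(loss.backward, device)
_, t_step = timed(opt.step, device)
print(f"forward {t_fwd:.4f}s  backward {t_bwd:.4f}s  step {t_step:.4f}s")
```

If the step still dominates once the phases are synchronized, the slowdown really is in the update; otherwise the 1.3 s probably belongs to the backward.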

Hey,

I am using unfold mainly because I want to work on overlapping 3D image patches with 3D convolutions; that is why I chain the calls. Is there a way to work on overlapping blocks without loops or unfold?
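One loop-free alternative to chained `unfold` (a known trick, not something proposed in this thread) is a `conv3d` with fixed one-hot kernels: output channel `i` copies the neighbor at offset `i` of each 3x3x3 window. The result has shape `(B, 27, D, H, W)`, a contiguous, channel-first layout that 3D conv layers can consume directly, and a bilateral-style average becomes `(weights * neighbors).sum(1) / weights.sum(1)` for weights of the same shape. It still stores 27 values per voxel, so on its own it does not shrink memory. The helper name `unfold3d_conv` is hypothetical.

```python
import torch
import torch.nn.functional as F


def unfold3d_conv(x, k=3):
    """Hypothetical helper: every k*k*k neighbor of each voxel as a channel.

    x: (B, 1, D, H, W) -> (B, k**3, D - k + 1, H - k + 1, W - k + 1).
    Each one-hot kernel selects a single offset in the window, so the
    convolution only copies values; nothing here is learned.
    """
    n = k ** 3
    eye = torch.eye(n, dtype=x.dtype, device=x.device).view(n, 1, k, k, k)
    return F.conv3d(x, eye)


x = F.pad(torch.randn(2, 1, 8, 16, 16), (1, 1, 1, 1, 1, 1))  # padded like forward()
neighbors = unfold3d_conv(x)  # (2, 27, 8, 16, 16)

# Same values as the chained Tensor.unfold windows, laid out channel-first.
ref = x.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)  # (2, 1, 8, 16, 16, 3, 3, 3)
ref = ref.reshape(2, 8, 16, 16, 27).permute(0, 4, 1, 2, 3)
print(neighbors.shape, torch.allclose(neighbors, ref))
```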

I will post a code sample a bit later as I am at work.

At some point someone has to write a 3D unfold…

You could use gradient checkpointing (torch.utils.checkpoint) to trade extra compute for less memory. Regarding speed (of the backward pass?), I seem to recall discussing with someone how to write a faster, deterministic backward for unfold, but I don't know if it ever made it into a patch.
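A minimal sketch of checkpointing applied to the window-heavy part of the computation. `neighborhood_filter` is a hypothetical stand-in (a toy range-only weighted average, not the poster's filter). Wrapped in `torch.utils.checkpoint.checkpoint`, its intermediates are not kept for backward but recomputed when gradients are needed, so the outputs and gradients should match the plain call.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def neighborhood_filter(x, w):
    # Hypothetical stand-in for the memory-hungry part of forward():
    # pad, take 3x3x3 windows, and average each window with weights that
    # decay with the squared difference to the window's center voxel.
    x = F.pad(x, (1, 1, 1, 1, 1, 1))
    win = x.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)  # (B, C, D, H, W, 3, 3, 3)
    weights = torch.exp(-w * (win - win[..., 1:2, 1:2, 1:2]) ** 2)
    return (weights * win).sum(dim=(-3, -2, -1)) / weights.sum(dim=(-3, -2, -1))


x = torch.randn(2, 1, 8, 16, 16, requires_grad=True)
w = torch.tensor(1.0, requires_grad=True)

# Plain call: autograd keeps the 27x-sized intermediates alive until backward.
out_plain = neighborhood_filter(x, w)
out_plain.sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed call: intermediates are dropped and recomputed during backward.
out_ckpt = checkpoint(neighborhood_filter, x, w, use_reentrant=False)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(out_plain, out_ckpt), torch.allclose(grad_plain, grad_ckpt))
```

In a real model the wrapped function would be the block between the padding and the weighted sum, so only its inputs and output are stored per batch.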

> At some point someone has to write a 3D unfold…

Well, it was in the original THNN here :wink:
Not sure why it was dropped…