I am attempting to implement an operation similar to cross bilateral filtering in the forward pass of a model I am building. Unfortunately, the `unfold` operations I use to create the sliding window are very memory-hungry: they limit my batch size to 4, even though my volumes are only 64 x 64 x 8. Applying the optimizer is also very slow: the update step takes 1.3 s per batch, while the forward and backward passes take around 0.05 s each. Is there any solution? I am aware of `F.unfold()`, but it currently supports only 2D images, not 3D volumes.
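To show where the memory goes, here is a minimal sketch that chains `Tensor.unfold` over the three spatial dimensions of a padded volume. It assumes a (batch, channel, depth, height, width) layout of (4, 1, 8, 64, 64) to match the sizes above. The chained unfold is only a strided view. Flattening it to `(N, 1, 3, 3, 3)` forces a copy that stores 27 values per voxel, and every tensor computed per neighborhood afterwards scales the same way.

```python
import torch
import torch.nn.functional as F

# Assumed sizes matching the post: batch 4, one channel, 8 x 64 x 64 volume.
x = torch.randn(4, 1, 8, 64, 64)
x = F.pad(x, (1, 1, 1, 1, 1, 1))  # pad D, H, W by one voxel on each side

# Chaining Tensor.unfold over the spatial dims gives a strided *view*
# of shape (B, C, D, H, W, 3, 3, 3); no data is copied yet.
patches = x.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)
print(patches.shape)            # torch.Size([4, 1, 8, 64, 64, 3, 3, 3])
print(patches.is_contiguous())  # False

# The overlapping windows cannot be flattened as a view, so reshape
# materializes a copy holding 27 values per voxel.
flat = patches.reshape(-1, 1, 3, 3, 3)
print(flat.numel() // (4 * 8 * 64 * 64))             # 27
print(flat.numel() * flat.element_size() / 2 ** 20)  # 13.5 (MiB, float32)
```

A single copy of this size is still modest. However, autograd keeps several of them alive (guide patches, input patches, kernels, products). Any network applied to the `N = B * D * H * W` small patches also produces activations proportional to `N`.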
Here is the code for the forward pass:
```python
import torch
import torch.nn.functional as F

def forward(self, x, domain_neighbor):
    # Compute the guidance image (x is zero-padded by one voxel on every side)
    x = F.pad(x, (1, 1, 1, 1, 1, 1), mode='constant')
    mat_size = x.shape
    guide_im = self.shared_denoiser(x)  # shared_denoiser is a weight-shared network

    # Extract 3x3x3 filter neighborhoods: (B, C, D, H, W, 3, 3, 3) -> (N, 1, 3, 3, 3)
    guide_im = guide_im.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1).reshape(-1, 1, 3, 3, 3)
    # Intensity difference of every neighbor relative to the center voxel
    range_neighbor = guide_im - guide_im[:, 0, 1, 1, 1].view(guide_im.shape[0], 1, 1, 1, 1)

    # Estimate filter coefficients
    domain_kernel = self.domain_coeffecients(domain_neighbor)
    range_kernel = self.range_coeffecients(range_neighbor)

    # Apply the bilateral filter: normalized weighted average over each neighborhood
    x = x.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1).reshape(-1, 1, 3, 3, 3)
    weights = domain_kernel * range_kernel
    filtered_pixel = torch.sum(weights * x, dim=[1, 2, 3, 4]) / torch.sum(weights, dim=[1, 2, 3, 4])
    # Assumes a single channel; mat_size is the padded shape, hence the -2
    return filtered_pixel.view(-1, 1, mat_size[2] - 2, mat_size[3] - 2, mat_size[4] - 2)
```
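You can also compute the same 3x3x3 weighted average without building the 27x tensor. Loop over the 27 offsets, take shifted slices of the padded volumes, and accumulate the numerator and denominator. The sketch below substitutes a classic Gaussian domain and range kernel for the learned `domain_coeffecients` / `range_coeffecients` networks. `cross_bilateral_3d`, `sigma_d` and `sigma_r` are hypothetical names. The learned networks could only be dropped in if they produce one weight per offset from per-voxel inputs.

```python
import math

import torch
import torch.nn.functional as F


def cross_bilateral_3d(x, guide, sigma_d=1.0, sigma_r=0.1):
    """Hypothetical cross (joint) bilateral filter over a 3x3x3 neighborhood.

    x, guide: unpadded tensors of shape (B, C, D, H, W).
    Gaussian domain/range kernels are an illustrative assumption, standing in
    for learned coefficient networks. Borders are zero-padded, as in the post.
    """
    xp = F.pad(x, (1, 1, 1, 1, 1, 1))
    gp = F.pad(guide, (1, 1, 1, 1, 1, 1))
    D, H, W = x.shape[2:]
    num = torch.zeros_like(x)
    den = torch.zeros_like(x)
    for dz in range(3):
        for dy in range(3):
            for dx in range(3):
                # Shifted slices: the neighbor at offset (dz-1, dy-1, dx-1)
                # for every voxel at once, each only one volume in size.
                x_n = xp[:, :, dz:dz + D, dy:dy + H, dx:dx + W]
                g_n = gp[:, :, dz:dz + D, dy:dy + H, dx:dx + W]
                # Domain weight depends only on the spatial offset (a scalar).
                dist2 = (dz - 1) ** 2 + (dy - 1) ** 2 + (dx - 1) ** 2
                w_d = math.exp(-dist2 / (2 * sigma_d ** 2))
                # Range weight compares each neighbor's guide value to the center's.
                w_r = torch.exp(-(g_n - guide) ** 2 / (2 * sigma_r ** 2))
                w = w_d * w_r
                num = num + w * x_n
                den = den + w
    # The center offset always contributes weight 1, so den >= 1.
    return num / den


vol = torch.randn(4, 1, 8, 64, 64, requires_grad=True)
out = cross_bilateral_3d(vol, vol.detach(), sigma_d=1.0, sigma_r=0.5)
out.sum().backward()  # gradients flow through the accumulated sums
```

The peak size of any single temporary here is one volume, not 27. If gradients are needed, autograd still saves some per-offset intermediates. Wrapping the loop in `torch.utils.checkpoint.checkpoint` is one way to trade that memory for recomputation.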
Any help is appreciated!