How can I more efficiently reduce gradient noise by using a continuous relaxation of the gradient for max pooling layers?

The code below was adapted from code written for TensorFlow 1.x, so it is probably not as efficient as it could be in PyTorch. When it replaces the MaxPool2d layers in a model, it reduces noise in the gradient by using a continuous relaxation of the gradient for max pooling.

The problem is that it is very slow and can cause out-of-memory errors (especially on CUDA), but I need it, or something like it, to calculate attributions for specific spatial positions. I am a bit out of my depth with this, so any help would be appreciated!

from typing import Optional, Tuple, Union

import torch


class MaxPool2dSmoothed(torch.nn.Module):
    """
    Pooling layer where, if we backprop through it,
    gradients get allocated proportional to the input activation.
    Then we backprop through that instead.

    Adapted from:
    https://github.com/tensorflow/lucid/blob/master/lucid/optvis/overrides/smoothed_maxpool_grad.py
    https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/building-blocks/AttrSpatial.ipynb
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Optional[Union[int, Tuple[int, ...]]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__()
        self.avgpool = torch.nn.AvgPool2d(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            ceil_mode=ceil_mode,
        )

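    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Smooth surrogate for max pooling: avg(x^2) / (eps + avg(|x|)).
        x_avg = self.avgpool(x ** 2) / (1e-2 + self.avgpool(torch.abs(x)))

        # Gradient of the surrogate w.r.t. x (x must require grad).
        # create_graph=True keeps this step itself differentiable.
        x_smoothed = torch.autograd.grad(
            outputs=[x_avg],
            inputs=[x],
            grad_outputs=[torch.ones_like(x_avg)],
            create_graph=True,
        )[0]
        return x_smoothed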


pool = MaxPool2dSmoothed(kernel_size=3, stride=2, padding=0, ceil_mode=True)
test_input = torch.nn.Parameter(torch.randn(1, 6, 64, 64))
test_output = pool(test_input.clone())
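
For comparison, as I read the linked Lucid override, it keeps the ordinary max-pool forward pass and only swaps the backward pass for the gradient of the smoothed surrogate. Below is a minimal sketch of that idea using `torch.autograd.Function`. The names `_SmoothedMaxPoolFn` and `SmoothedMaxPool2d` are hypothetical, not part of PyTorch or Lucid, and this is an assumption about what a more memory-friendly variant might look like rather than a tested replacement. Because the surrogate is recomputed inside `backward`, no second-order graph is built during the forward pass, which may help with the slowdown and memory use; the trade-off is that this sketch does not support double backward.

```python
import torch
import torch.nn.functional as F


class _SmoothedMaxPoolFn(torch.autograd.Function):
    """Hypothetical helper: max-pool forward, smoothed-surrogate backward."""

    @staticmethod
    def forward(ctx, x, kernel_size, stride, padding, ceil_mode):
        ctx.save_for_backward(x)
        ctx.pool_args = (kernel_size, stride, padding, ceil_mode)
        # The forward pass is a plain max pool, so model outputs are unchanged.
        return F.max_pool2d(x, kernel_size, stride, padding, ceil_mode=ceil_mode)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        k, s, p, c = ctx.pool_args
        # Grad mode is off inside backward by default, so re-enable it locally.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            num = F.avg_pool2d(x_ ** 2, k, s, p, ceil_mode=c)
            den = 1e-2 + F.avg_pool2d(x_.abs(), k, s, p, ceil_mode=c)
            smooth = num / den
            # Route the upstream gradient through the smooth surrogate.
            (grad_input,) = torch.autograd.grad(smooth, x_, grad_output)
        # One gradient per forward argument; the pooling settings get None.
        return grad_input, None, None, None, None


class SmoothedMaxPool2d(torch.nn.Module):
    """Hypothetical drop-in module wrapping the function above."""

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ceil_mode = ceil_mode

    def forward(self, x):
        return _SmoothedMaxPoolFn.apply(
            x, self.kernel_size, self.stride, self.padding, self.ceil_mode
        )


smoothed_pool = SmoothedMaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
sketch_input = torch.randn(1, 6, 64, 64, requires_grad=True)
sketch_output = smoothed_pool(sketch_input)  # same values as a regular max pool
sketch_output.sum().backward()               # gradient comes from the surrogate
```

With this layout the output has the pooled shape, and only `sketch_input.grad` differs from what a regular `MaxPool2d` would produce.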