CrossHair filter in torch.nn.Conv3d()

Hi all,

I built a 3D U-Net using Pytorch. The training time is huge due to the characteristics of the dataset and in an attempt of trying to reduce the running time I am trying to create a CrossHair filter for torch.nn.Conv3d() and torch.nn.ConvTranspose3d() following the idea of DeepVesselNet in order to get less trainable parameters or reduce computation time by having many zeros in the filter.

For that I changed the Conv3d and ConTranspose3d functions from Pytorch as follows:

class CHConv3d(_ConvNd):
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
        #self.weight.data = self.cross_hair_kernel(in_channels, out_channels, kernel_size_)
        self.weight.data[self.cross_hair_kernel(in_channels, out_channels, kernel_size)!=1]=0.

    def _conv_forward(self, input: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != "zeros":
            return F.conv3d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                self.weight,
                bias,
                self.stride,
                _triple(0),
                self.dilation,
                self.groups,
            )
        return F.conv3d(
            input, self.weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.bias)
    
    @staticmethod
    def cross_hair_kernel(in_dim, out_dim, size):
        kernel = torch.empty(size)
        center = size[0] // 2
        kernel[center, :, :] = 1
        kernel[:, center, :] = 1
        kernel[:, :, center] = 1
        kernel = kernel.repeat(in_dim, 1, 1, 1)
        return kernel.repeat(out_dim, 1, 1, 1, 1)

class CHConvTranspose3d(_ConvTransposeNd):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)
        self.weight.data[self.cross_hair_kernel_tr(in_channels, out_channels, kernel_size)!=1]=0.

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        return F.conv_transpose3d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
    
    @staticmethod
    def cross_hair_kernel_tr(in_dim, out_dim, size):
        kernel = torch.empty(size)
        center = size[0] // 2
        kernel[center, :, :] = 1
        kernel[:, center, :] = 1
        kernel[:, :, center] = 1
        kernel = kernel.repeat(out_dim, 1, 1, 1)
        return kernel.repeat(in_dim, 1, 1, 1, 1)

but it seems to take the same running time. Another approach I was not able to carry out is to combine 3 convolutions in which each kernel would be (kernel_size[0], 1, kernel_size[2]), (kernel_size[0], , kernel_size[1], 1), and (1, kernel_size[2], kernel_size[2]), respectively.

If anyone has an idea on how to proceed in Pytorch it would be highly appreciated. Thank you in advance!

Hi MLB!

You won’t save any time by using a full kernel that happens to have a
bunch of zeros in it. Pytorch will still spend as much time multiplying
whatever you’re convolving with those zeros (and getting zeros as the
products) as it would if the kernel elements were not zero.

(Furthermore, if I understand your code correctly, the zeros in your kernel
won’t stay zero – as you train, the optimizer will change those zeros to
non-zero values.)

Some further comments:

Don’t use data. It is deprecated and can lead to errors.

The “approved” way to reinitialize (or otherwise update) weight is to use an
inplace modification such as weight.copy_ (...) or weight[...] = ...
with the modification wrapped in a with torch.no_grad(): block.

torch.empty() gives you a tensor with uninitialized storage, so its values
could start out as anything (including 1.0, which would break your current
logic). You presumably want torch.zeros() here.

This seems to me to be the best way to achieve what you want. You didn’t
say what your issue was with this approach, but it can be made to work.

Note, unless the size of your “cross-hair” kernel is rather large, you might not
get much (if anything) in terms of speed-up using three such kernels instead
of a single full kernel (with lots of zeros), even though the three kernels together
will have fewer parameters than the single full kernel.

Your code is rather layered and complicated, presumably in the interest of
generality. If you have further questions, please post a reasonably-simplified,
fully-self-contained, runnable script that illustrates your issues / questions,
together with the output you obtain when you run it.

Best.

K. Frank

1 Like