Hi all,
I built a 3D U-Net using PyTorch. Training takes a very long time because of the characteristics of the dataset, so, to reduce the running time, I am trying to create a cross-hair filter for `torch.nn.Conv3d()` and `torch.nn.ConvTranspose3d()`, following the idea of DeepVesselNet, in order to get fewer trainable parameters or to reduce computation time by having many zeros in the filter.
To do so, I modified PyTorch's `Conv3d` and `ConvTranspose3d` classes as follows:
class CHConv3d(_ConvNd):
    """3D convolution whose kernels are restricted to a "cross-hair" pattern.

    Only the three axis-aligned planes through the kernel centre
    (``d == kD // 2``, ``h == kH // 2``, ``w == kW // 2``) are used, in the
    spirit of DeepVesselNet. Every other kernel entry is forced to zero.

    The constraint is enforced on *every* forward pass by multiplying
    ``self.weight`` with a fixed boolean mask (``self.ch_mask``). Two things
    follow from this:

    * The masked entries receive zero gradient.
    * Optimizer updates or re-initialisation cannot break the pattern.

    Zeroing the weight once at construction is not enough on its own: the
    first optimizer step would refill those entries.

    Note: the masked weight is still a dense ``(kD, kH, kW)`` tensor, and
    ``F.conv3d`` does not skip zero entries. This layer therefore performs
    the same amount of compute as a plain ``Conv3d`` with the same kernel
    size. It constrains the model; it does not make it faster.

    Constructor arguments mirror ``torch.nn.Conv3d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
        # The mask has shape (1, 1, kD, kH, kW). It broadcasts over the
        # (out_channels, in_channels // groups) weight dims, so it is valid
        # for any `groups`. The tuple `kernel_size_` is used (not the raw
        # argument) so that int kernel sizes work.
        mask = self.cross_hair_kernel(1, 1, kernel_size_).to(self.weight.device) == 1
        # Non-persistent buffer: it moves with .to()/.cuda() but does not add
        # a key to state_dict, so existing checkpoints still load.
        self.register_buffer("ch_mask", mask, persistent=False)
        with torch.no_grad():
            # Keep the stored parameter itself consistent with the pattern.
            self.weight.mul_(mask)

    def masked_weight(self) -> Tensor:
        """Return ``self.weight`` with every non-cross-hair entry zeroed."""
        return self.weight * self.ch_mask

    def _conv_forward(self, input: Tensor, bias: Optional[Tensor]):
        """Run ``F.conv3d`` with the cross-hair-masked weight.

        For non-``'zeros'`` padding modes the input is padded explicitly with
        ``F.pad``, and the convolution then runs with zero implicit padding.
        """
        weight = self.masked_weight()
        if self.padding_mode != "zeros":
            padded = F.pad(
                input, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return F.conv3d(
                padded, weight, bias, self.stride, _triple(0), self.dilation, self.groups
            )
        return F.conv3d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the cross-hair convolution (plus bias, if any) to ``input``."""
        return self._conv_forward(input, self.bias)

    @staticmethod
    def cross_hair_kernel(in_dim, out_dim, size):
        """Return a 0/1 cross-hair pattern of shape ``(out_dim, in_dim, kD, kH, kW)``.

        ``size`` may be an int or a 3-tuple ``(kD, kH, kW)``; the kernel does
        not need to be cubic. An entry is 1 when it lies on at least one of
        the planes ``d == kD // 2``, ``h == kH // 2`` or ``w == kW // 2``,
        and 0 otherwise.
        """
        kd, kh, kw = _triple(size)
        # zeros (not empty): entries outside the cross must be exactly 0.
        kernel = torch.zeros(kd, kh, kw)
        kernel[kd // 2, :, :] = 1
        kernel[:, kh // 2, :] = 1
        kernel[:, :, kw // 2] = 1
        return kernel.repeat(out_dim, in_dim, 1, 1, 1)
class CHConvTranspose3d(_ConvTransposeNd):
    """3D transposed convolution with kernels restricted to a "cross-hair" pattern.

    Only the three axis-aligned planes through the kernel centre
    (``d == kD // 2``, ``h == kH // 2``, ``w == kW // 2``) are used. Every
    other kernel entry is forced to zero by multiplying ``self.weight`` with
    a fixed boolean mask (``self.ch_mask``) on every forward pass. As a
    result, the masked entries get zero gradient, and optimizer updates
    cannot break the pattern.

    Note: the masked weight is still dense, and ``F.conv_transpose3d`` does
    not skip zero entries. The layer therefore costs the same compute as a
    plain ``ConvTranspose3d`` of the same kernel size.

    Constructor arguments mirror ``torch.nn.ConvTranspose3d``. Only
    ``padding_mode='zeros'`` is supported in ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)
        # The mask has shape (1, 1, kD, kH, kW). It broadcasts over the
        # (in_channels, out_channels // groups) weight dims, so it is valid
        # for any `groups`.
        mask = self.cross_hair_kernel_tr(1, 1, kernel_size).to(self.weight.device) == 1
        # Non-persistent buffer: it follows .to()/.cuda() without changing
        # the state_dict keys.
        self.register_buffer("ch_mask", mask, persistent=False)
        with torch.no_grad():
            # Keep the stored parameter itself consistent with the pattern.
            self.weight.mul_(mask)

    def masked_weight(self) -> Tensor:
        """Return ``self.weight`` with every non-cross-hair entry zeroed."""
        return self.weight * self.ch_mask

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the cross-hair transposed convolution to ``input``.

        ``output_size`` (optional) selects the output padding, as in
        ``torch.nn.ConvTranspose3d``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]
        return F.conv_transpose3d(
            input, self.masked_weight(), self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

    @staticmethod
    def cross_hair_kernel_tr(in_dim, out_dim, size):
        """Return a 0/1 cross-hair pattern of shape ``(in_dim, out_dim, kD, kH, kW)``.

        The layout matches the transposed-convolution weight
        ``(in_channels, out_channels // groups, ...)``. ``size`` may be an int
        or a 3-tuple and need not be cubic. An entry is 1 when it lies on at
        least one of the centre planes of its axis, and 0 otherwise.
        """
        kd, kh, kw = _triple(size)
        # zeros (not empty): entries outside the cross must be exactly 0.
        kernel = torch.zeros(kd, kh, kw)
        kernel[kd // 2, :, :] = 1
        kernel[:, kh // 2, :] = 1
        kernel[:, :, kw // 2] = 1
        return kernel.repeat(in_dim, out_dim, 1, 1, 1)
However, it seems to take the same running time. Another approach I was not able to carry out is to combine three convolutions whose kernels would be (kernel_size[0], 1, kernel_size[2]), (kernel_size[0], kernel_size[1], 1), and (1, kernel_size[1], kernel_size[2]), respectively.
If anyone has an idea of how to proceed in PyTorch, it would be highly appreciated. Thank you in advance!