I have written a custom convolutional layer that has a `Linear` layer as a member. However, the `Linear` layer's weights are just a `.view` of the weights of the convolutional layer:
```python
class Conv2d(AISBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2int,
        stride: _size_2int = 1,
        padding: Union[str, _size_2int] = 0,
        dilation: _size_2int = 1,
        groups: int = 1,
        bias: bool = False,
        device=None,
        dtype=None,
    ):
        super(Conv2d, self).__init__()
        # ... standard stuff here
        self.weight = Parameter(
            torch.empty((out_channels, in_channels, *self.kernel_size), **factory_kwargs)
        )
        if bias:
            self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        # ... more stuff
        self.reset_parameters()

        # in_features of the equivalent linear layer (self.kernel_size is stored as a 2-tuple)
        n_rows = self.kernel_size[0] * self.kernel_size[1] * in_channels
        n_cols = out_channels
        self.linear = Linear(
            in_features=n_rows,
            out_features=n_cols,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        # Link the parameters of the conv layer with those of the linear layer through a view
        self.linear.weight = Parameter(self.weight.view(self.weight.size(0), -1))
        self.linear.bias = None if not bias else Parameter(self.bias.view(-1))
```
Viewed from the outside, this linear layer should not be visible; in particular, it should not show up in the model's `state_dict`. However, I have to use the `Linear` class, and for that its weight has to be a `Parameter` (hence the `Parameter(...view(...))` wrapping above).
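Concretely, the keys I currently get look roughly like this (illustrative snippet with made-up channel/kernel numbers and `bias=False`):

```python
conv = Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3))
print(list(conv.state_dict().keys()))
# ['weight', 'linear.weight']  <- I would like only ['weight'] to show up here
```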
I also want to be able to call `load_state_dict` with a state dict that does not contain these linear parameters. If I try to load one now, PyTorch of course complains that the parameters of `self.linear` are not matched.
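For example (continuing the snippet above), loading a checkpoint that only contains the conv weight fails the strict check, with `linear.weight` reported as a missing key:

```python
plain_sd = {"weight": torch.randn(8, 3, 3, 3)}
conv.load_state_dict(plain_sd)  # raises: linear.weight is reported as missing (strict=True)
```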
Is there maybe an elegant way, through hooks or something similar, that solves my problem?
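To make the question more concrete, this is roughly the direction I was imagining, shown on a small toy stand-in for my class. It is an untested sketch that relies on the private `_register_state_dict_hook` and `_register_load_state_dict_pre_hook` methods of `nn.Module`, and I am not sure this is the intended or stable way to do it:

```python
import torch
from torch import nn


class ConvWithHiddenLinear(nn.Module):
    """Toy stand-in for the Conv2d above: a conv weight plus a Linear
    whose weight is a flattened view of the conv weight."""

    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size)
        )
        nn.init.kaiming_uniform_(self.weight)
        self.linear = nn.Linear(
            in_channels * kernel_size * kernel_size, out_channels, bias=bias
        )
        self.linear.weight = nn.Parameter(self.weight.view(out_channels, -1))

        # Remove the linear keys from the result of state_dict()
        # (this hook runs after the children have written their entries).
        self._register_state_dict_hook(self._drop_linear_keys)
        # Re-create the linear keys from the conv keys right before this
        # module (and then its children) is loaded, so strict=True passes.
        self._register_load_state_dict_pre_hook(self._fill_linear_keys)

    @staticmethod
    def _drop_linear_keys(module, state_dict, prefix, local_metadata):
        for key in (prefix + "linear.weight", prefix + "linear.bias"):
            state_dict.pop(key, None)

    def _fill_linear_keys(self, state_dict, prefix, *args):
        conv_key = prefix + "weight"
        if conv_key in state_dict:
            state_dict[prefix + "linear.weight"] = state_dict[conv_key].reshape(
                state_dict[conv_key].shape[0], -1
            )
        if self.linear.bias is not None and prefix + "bias" in state_dict:
            state_dict[prefix + "linear.bias"] = state_dict[prefix + "bias"].reshape(-1)

    # forward() omitted, it is not relevant for the state_dict question


m = ConvWithHiddenLinear(3, 8, 3)
print(list(m.state_dict().keys()))                       # ['weight'] only
m.load_state_dict({"weight": torch.randn(8, 3, 3, 3)})   # no missing-key error
```

The idea is that `state_dict()` then only exposes the conv parameters, and when loading, the `linear.*` entries are synthesized from the conv entries before the child `Linear` is visited, so the strict check passes. Is that reasonable, or is there a cleaner, officially supported way?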