Make specific parameters disappear in state_dict

I have written a custom convolutional layer that has a Linear layer as a member. However, the Linear layer’s weights are just a .view of the weights of the convolutional layer.

class Conv2d(AISBase):
    def __init__(
        self,
        in_channels : int,
        out_channels : int,
        kernel_size : _size_2int,
        stride : _size_2int = 1,
        padding : Union[str, _size_2int] = 0,
        dilation : _size_2int = 1,
        groups : int = 1,
        bias : bool = False,
        device = None,
        dtype = None,
    ):
        super(Conv2d, self).__init__()
        # ... standard stuff here (defines self.kernel_size and factory_kwargs, among others)
        self.weight = Parameter(
            torch.empty((out_channels, in_channels, *self.kernel_size), **factory_kwargs)
        )
        if bias:
            self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        # ... more stuff

        # Linear(in_features, out_features): one row per output channel,
        # one column per (in_channel, kernel_row, kernel_col) entry
        n_rows = self.kernel_size[0] * self.kernel_size[1] * in_channels
        n_cols = out_channels
        self.linear = Linear(n_rows, n_cols, bias=bias, **factory_kwargs)
        # Link the parameters of the conv layer with those of the linear layer through a view
        self.linear.weight = Parameter(self.weight.view(self.weight.size(0), -1))
        self.linear.bias = None if not bias else Parameter(self.bias.view(-1))
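The post does not show the forward pass, so the following is only a sketch under an assumption: that the hidden Linear is applied to im2col-style patches produced by `torch.nn.functional.unfold`. Under that assumption, a Linear whose weight is `weight.view(out_channels, -1)` reproduces `F.conv2d` (stride 1, no padding, no dilation here), and the viewed weight shares memory with the conv weight. The variable names are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

in_channels, out_channels, k = 3, 4, 3
weight = torch.randn(out_channels, in_channels, k, k)
bias = torch.randn(out_channels)

# A Linear whose parameters alias the conv parameters, as in the post
linear = nn.Linear(in_channels * k * k, out_channels)
linear.weight = nn.Parameter(weight.view(out_channels, -1))
linear.bias = nn.Parameter(bias.view(-1))

x = torch.randn(2, in_channels, 8, 8)

# im2col: shape (N, C*k*k, L), where L is the number of sliding positions
patches = F.unfold(x, kernel_size=k)
# Linear acts on the last dim: (N, L, C*k*k) -> (N, L, out) -> (N, out, L)
out = linear(patches.transpose(1, 2)).transpose(1, 2)
# Positions are enumerated row-major, so they fold back into an (H_out, W_out) map
h_out, w_out = x.shape[2] - k + 1, x.shape[3] - k + 1
out = out.reshape(x.shape[0], out_channels, h_out, w_out)

reference = F.conv2d(x, weight, bias)
print(torch.allclose(out, reference, atol=1e-5))
print(linear.weight.data_ptr() == weight.data_ptr())  # same underlying memory
```

Because `nn.Parameter` detaches its input but keeps the same storage, the Linear's weight is a separate Parameter object that still reads and writes the conv weight's memory.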

From outside the network, this linear layer should not be visible, so its parameters should not appear in the model's state_dict. However, I have to use the Linear class, and for that the weight has to be a Parameter.
I also want to be able to pass a state_dict that lacks these linear parameters to load_state_dict. If I try that now, PyTorch of course complains that the parameters of self.linear are missing.
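Both symptoms can be reproduced with a small stand-in module. `SharedConv` below is a hypothetical, stripped-down version of the Conv2d above (no bias, random init) used only for illustration. Passing `strict=False` to `load_state_dict` is a blunt workaround: it returns the missing keys instead of raising, but it would also silence keys that are genuinely missing.

```python
import torch
import torch.nn as nn


class SharedConv(nn.Module):  # hypothetical stand-in for the post's Conv2d
    def __init__(self, in_channels=3, out_channels=4, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, k, k))
        self.linear = nn.Linear(in_channels * k * k, out_channels, bias=False)
        # Same trick as in the post: the Linear weight aliases the conv weight
        self.linear.weight = nn.Parameter(self.weight.view(out_channels, -1))


m = SharedConv()
keys = list(m.state_dict().keys())
print(keys)  # the hidden Linear shows up next to the conv weight

# A checkpoint that only contains the conv weight
checkpoint = {"weight": torch.zeros_like(m.weight)}
try:
    m.load_state_dict(checkpoint)  # strict=True by default
except RuntimeError as err:
    print(err)  # complains about the missing Linear key

result = m.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)
```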

Is there a neat way, maybe with hooks or something similar, to solve this?

Ok, this is resolved now. I added this to my Conv module:

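def state_dict_hook(self, destination, prefix, local_metadata):
    # destination holds the whole model's entries, so match on this module's prefix;
    # a plain `"linear" in k` would also drop keys of other modules named *linear*
    keys_to_remove = [k for k in destination if k.startswith(prefix + "linear.")]
    for k in keys_to_remove:
        del destination[k]
    return destination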

Maybe not the most elegant, but it works fine.
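The hook above only affects saving. A minimal sketch of handling the loading side as well, assuming the module is allowed to rebuild the hidden Linear entries from its own conv entries: a load-state-dict pre-hook adds `linear.weight`/`linear.bias` back into the (internally copied) state dict before strict checking, so checkpoints without them load with the default `strict=True`. `SharedConv`, `_drop_linear_keys`, and `_add_linear_keys` are hypothetical names; `_register_state_dict_hook` and `_register_load_state_dict_pre_hook` are private `torch.nn.Module` methods, so their behavior may change between releases.

```python
import torch
import torch.nn as nn


def _drop_linear_keys(module, state_dict, prefix, local_metadata):
    # Runs after this module and its children were saved: drop only this
    # module's hidden Linear entries (modify in place, return None).
    for key in [k for k in state_dict if k.startswith(prefix + "linear.")]:
        del state_dict[key]


class SharedConv(nn.Module):  # hypothetical stand-in for the post's Conv2d
    def __init__(self, in_channels=3, out_channels=4, k=3, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, k, k))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter("bias", None)
        self.linear = nn.Linear(in_channels * k * k, out_channels, bias=bias)
        self.linear.weight = nn.Parameter(self.weight.view(out_channels, -1))
        if bias:
            self.linear.bias = nn.Parameter(self.bias.view(-1))

        # Private (underscore) APIs of torch.nn.Module
        self._register_state_dict_hook(_drop_linear_keys)
        self._register_load_state_dict_pre_hook(self._add_linear_keys)

    def _add_linear_keys(self, state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
        # Recreate the hidden Linear entries from the conv entries so that
        # strict loading finds them; load_state_dict works on a copy of the
        # caller's dict, so the caller's checkpoint is left untouched.
        if prefix + "weight" in state_dict:
            w = state_dict[prefix + "weight"]
            state_dict[prefix + "linear.weight"] = w.reshape(w.size(0), -1)
        if self.bias is not None and prefix + "bias" in state_dict:
            state_dict[prefix + "linear.bias"] = state_dict[prefix + "bias"].reshape(-1)


net = nn.Sequential(SharedConv(), nn.ReLU())
sd = net.state_dict()
print(list(sd.keys()))  # only the conv entries, under the "0." prefix

fresh = nn.Sequential(SharedConv(), nn.ReLU())
fresh.load_state_dict(sd)  # strict=True, no missing-key error
```

Parameters are loaded with in-place copies, so after loading, the hidden Linear still aliases the conv weight of `fresh`.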