I have written a custom convolutional layer that has a `Linear` layer as a member. However, the `Linear` layer's weights are just a `.view` of the weights of the convolutional layer:
```python
class Conv2d(AISBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2int,
        stride: _size_2int = 1,
        padding: Union[str, _size_2int] = 0,
        dilation: _size_2int = 1,
        groups: int = 1,
        bias: bool = False,
        device=None,
        dtype=None,
    ):
        super(Conv2d, self).__init__()
        # ... standard stuff here
        self.weight = Parameter(
            torch.empty((out_channels, in_channels, *self.kernel_size), **factory_kwargs)
        )
        if bias:
            self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        # ... more stuff
        self.reset_parameters()

        # in_features of the equivalent linear layer (self.kernel_size is stored as a 2-tuple)
        n_rows = self.kernel_size[0] * self.kernel_size[1] * in_channels
        n_cols = out_channels
        self.linear = Linear(
            in_features=n_rows,
            out_features=n_cols,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        # Link the parameters of the conv layer with those of the linear layer through a view
        self.linear.weight = Parameter(self.weight.view(self.weight.size(0), -1))
        self.linear.bias = None if not bias else Parameter(self.bias.view(-1))
```
Viewed from the outside, this linear layer should not be visible; in particular, it should not show up in the model's `state_dict`. However, I have to use the `Linear` class, and for that its weight has to be a `Parameter` (hence the `Parameter(...view(...))` wrapping above).
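Concretely, the keys I currently get look roughly like this (illustrative snippet with made-up channel/kernel numbers and `bias=False`):

```python
conv = Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3))
print(list(conv.state_dict().keys()))
# ['weight', 'linear.weight']  <- I would like only ['weight'] to show up here
```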
I also want to be able to call `load_state_dict` with a state dict that does not contain these linear parameters. If I try to load one now, PyTorch of course complains that the parameters of `self.linear` are not matched.
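For example (continuing the snippet above), loading a checkpoint that only contains the conv weight fails the strict check, with `linear.weight` reported as a missing key:

```python
plain_sd = {"weight": torch.randn(8, 3, 3, 3)}
conv.load_state_dict(plain_sd)  # raises: linear.weight is reported as missing (strict=True)
```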
Is there maybe an elegant way, through hooks or something similar, that solves my problem?
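To make the question more concrete, this is roughly the direction I was imagining, shown on a small toy stand-in for my class. It is an untested sketch that relies on the private `_register_state_dict_hook` and `_register_load_state_dict_pre_hook` methods of `nn.Module`, and I am not sure this is the intended or stable way to do it:

```python
import torch
from torch import nn


class ConvWithHiddenLinear(nn.Module):
    """Toy stand-in for the Conv2d above: a conv weight plus a Linear
    whose weight is a flattened view of the conv weight."""

    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size)
        )
        nn.init.kaiming_uniform_(self.weight)
        self.linear = nn.Linear(
            in_channels * kernel_size * kernel_size, out_channels, bias=bias
        )
        self.linear.weight = nn.Parameter(self.weight.view(out_channels, -1))

        # Remove the linear keys from the result of state_dict()
        # (this hook runs after the children have written their entries).
        self._register_state_dict_hook(self._drop_linear_keys)
        # Re-create the linear keys from the conv keys right before this
        # module (and then its children) is loaded, so strict=True passes.
        self._register_load_state_dict_pre_hook(self._fill_linear_keys)

    @staticmethod
    def _drop_linear_keys(module, state_dict, prefix, local_metadata):
        for key in (prefix + "linear.weight", prefix + "linear.bias"):
            state_dict.pop(key, None)

    def _fill_linear_keys(self, state_dict, prefix, *args):
        conv_key = prefix + "weight"
        if conv_key in state_dict:
            state_dict[prefix + "linear.weight"] = state_dict[conv_key].reshape(
                state_dict[conv_key].shape[0], -1
            )
        if self.linear.bias is not None and prefix + "bias" in state_dict:
            state_dict[prefix + "linear.bias"] = state_dict[prefix + "bias"].reshape(-1)

    # forward() omitted, it is not relevant for the state_dict question


m = ConvWithHiddenLinear(3, 8, 3)
print(list(m.state_dict().keys()))                       # ['weight'] only
m.load_state_dict({"weight": torch.randn(8, 3, 3, 3)})   # no missing-key error
```

The idea is that `state_dict()` then only exposes the conv parameters, and when loading, the `linear.*` entries are synthesized from the conv entries before the child `Linear` is visited, so the strict check passes. Is that reasonable, or is there a cleaner, officially supported way?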