Quantize a Factorized Linear Layer

Hi,

I have created a small layer that acts as a drop-in replacement for nn.Linear to make my networks smaller:

import torch
from torch import nn


class FactorizedLinear(nn.Module):
    def __init__(self,
                 or_linear,
                 dim_ratio=1.0,
                 random_init=False):
        super().__init__()
        # the original layer must have a bias, it is carried over as-is
        self.bias = nn.Parameter(or_linear.bias.data)
        if random_init:
            u, vh = self.random_init(or_linear.weight.data, dim_ratio=dim_ratio)
            print(f'Doing zero init of tensor {or_linear.weight.shape}, U: {u.shape}, Vh: {vh.shape}')
        else:
            u, vh = self.spectral_init(or_linear.weight.data, dim_ratio=dim_ratio)
            print(f'Doing SVD of tensor {or_linear.weight.shape}, U: {u.shape}, Vh: {vh.shape}')
        self.u = nn.Parameter(u)
        self.vh = nn.Parameter(vh)
        self.dim_ratio = dim_ratio
        # weight has shape (out_features, in_features), so u is (out, k) and vh is (k, in)
        self.in_features = vh.size(1)
        self.out_features = u.size(0)

    @staticmethod
    @torch.jit.ignore
    def spectral_init(m,
                      dim_ratio=1):
        # split the weight into two factors, sharing sqrt(s) between them
        u, s, vh = torch.linalg.svd(m, full_matrices=False)
        u = u @ torch.diag(torch.sqrt(s))
        vh = torch.diag(torch.sqrt(s)) @ vh
        if dim_ratio < 1:
            # keep only the top singular directions
            dims = int(u.size(1) * dim_ratio)
            u = u[:, :dims]
            vh = vh[:dims, :]
            s_share = s[:dims].sum() / s.sum() * 100
            print(f'SVD singular value share {s_share:.2f}%')
        return u, vh

    @staticmethod
    @torch.jit.ignore
    def random_init(m,
                    dim_ratio=1):
        # despite the name, the factors are zero-initialized (see the "zero init" print above)
        bottleneck = int(m.size(1) * dim_ratio)
        u = torch.zeros(m.size(0), bottleneck)
        vh = torch.zeros(bottleneck, m.size(1))
        return u, vh

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, '
                f'out_features={self.out_features}, '
                f'bias=True, dim_ratio={self.dim_ratio}')

    def forward(self, x):
        # same as F.linear(x, self.u @ self.vh, self.bias)
        return x @ (self.u @ self.vh).transpose(0, 1) + self.bias
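
For reference, this is roughly how I plug it into an existing model (the replace_linears helper and the toy Sequential below are just for illustration, not part of the layer itself):

def replace_linears(module, dim_ratio=0.25):
    # walk the module tree and swap every nn.Linear that has a bias
    # for a FactorizedLinear built from its weights
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and child.bias is not None:
            setattr(module, name, FactorizedLinear(child, dim_ratio=dim_ratio))
        else:
            replace_linears(child, dim_ratio=dim_ratio)

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 256))
replace_linears(model, dim_ratio=0.25)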

In practice, when I use dim_ratio=0.25 I can achieve a 50% smaller network with slightly worse performance (~10%).
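
The arithmetic behind that number: a dense (out, in) weight stores out * in parameters, while the factorized pair stores out * k + k * in with k = int(min(out, in) * dim_ratio), so a square layer shrinks to roughly 2 * dim_ratio of its original size:

out_f, in_f = 512, 512
k = int(min(out_f, in_f) * 0.25)     # dim_ratio = 0.25 -> k = 128
dense = out_f * in_f                 # 262144 parameters
factorized = out_f * k + k * in_f    # 131072 parameters
print(factorized / dense)            # 0.5, i.e. the ~50% smaller network above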

I took a look at the quantized Linear module, and this line in its forward more or less tells me that you cannot easily extend that class with my logic:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.linear(
            x, self._packed_params._packed_params, self.scale, self.zero_point)

I can of course use self._packed_params = torch.ops.quantized.linear_prepack(weight, bias) to distribute the "small" package, pre-compute the self.u @ self.vh part, and essentially create the quantized network on the fly. But that would make model packaging / distribution very complicated compared to just loading a JIT file.

All of this is very experimental, and so far I have only played with dim_ratio=0.25, which makes the network itself 2x smaller. But quantization (in my case) usually also provides about a 2-3x total module size reduction, so I would need to compress with dim_ratio=0.1 or lower to produce really small networks without quantization.

Maybe I am still missing an elephant in the room, idk. Hope someone finds this useful or has some ideas.

Combining quantization (up to 4x smaller networks) with factorization (2-4x smaller) looks like a killer feature to me.


Could you precompute self.u @ self.vh and then transform FactorizedLinear into a plain Linear before quantization, so that it can be quantized to QuantizedLinear and get the size reduction?
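
Something along these lines, as a rough sketch (fold_factorized is just a name for the idea, not an existing API):

def fold_factorized(f_linear):
    # collapse u @ vh back into a dense weight and return a plain nn.Linear
    weight = f_linear.u.data @ f_linear.vh.data        # shape (out_features, in_features)
    linear = nn.Linear(weight.size(1), weight.size(0), bias=True)
    linear.weight.data = weight
    linear.bias.data = f_linear.bias.data
    return linear

# after swapping every FactorizedLinear for fold_factorized(...), the model can be
# quantized like any other, e.g. with dynamic quantization:
# model_q = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)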


Yeah, I thought about this. But in our particular case this would make model distribution much more complicated (or it would negate the effects of factorization).

But I found an elegant solution. I forgot that bias=True is not mandatory in PyTorch, so the factorization can simply be expressed as two stock Linear layers, one of them without a bias. You can pass the above class to the class below and it will just quantize:

class FactorizedQLinear(nn.Module):
    def __init__(self,
                 f_linear):
        super().__init__()
        self.in_features = f_linear.in_features
        self.out_features = f_linear.out_features
        self.dim_ratio = f_linear.dim_ratio
        # u has shape (out_features, k), vh has shape (k, in_features);
        # nn.Linear stores its weight as (out_features, in_features)
        self.u_linear = nn.Linear(in_features=f_linear.u.data.size(1),
                                  out_features=f_linear.u.data.size(0),
                                  bias=True)
        self.vh_linear = nn.Linear(in_features=f_linear.vh.data.size(1),
                                   out_features=f_linear.vh.data.size(0),
                                   bias=False)
        self.u_linear.weight.data = f_linear.u.data
        self.u_linear.bias.data = f_linear.bias.data
        self.vh_linear.weight.data = f_linear.vh.data

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, '
                f'out_features={self.out_features}, '
                f'bias=True, dim_ratio={self.dim_ratio}')

    def forward(self, x):
        # project down with vh first, then back up with u (matches u @ vh)
        return self.u_linear(self.vh_linear(x))
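
For completeness, here is roughly how the conversion and quantization fit together (swap_factorized is just an illustrative helper, and dynamic quantization is only one example of how to quantize afterwards):

def swap_factorized(module):
    # replace every FactorizedLinear with its quantization-friendly twin
    for name, child in module.named_children():
        if isinstance(child, FactorizedLinear):
            setattr(module, name, FactorizedQLinear(child))
        else:
            swap_factorized(child)

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
model[0] = FactorizedLinear(model[0], dim_ratio=0.25)

swap_factorized(model)
# the two inner nn.Linear layers are now ordinary modules, so e.g. dynamic
# quantization picks them up like any other Linear
model_q = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)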

Also, this method really works. There are some sacrifices in quality (we do s2s), but for plain classification it will probably work just fine.

Also, quantization used together with factorization can reduce your model size by ~10x, which is nice!

An xsmall_q model here - Quality Benchmarks · snakers4/silero-models Wiki (github.com) - is 30x smaller than the large models, and it is created by a combination of quantization / minification / factorization.
