What does nn.utils.skip_init() really do?

I am working with the Linear4bit class from bitsandbytes.

Running this snippet

import bitsandbytes as bnb
import torch.nn as nn
import torch

model = nn.utils.skip_init(
            bnb.nn.Linear4bit,
            16,
            16,
            bias=None,
            quant_type='fp4',
            device=torch.device("cpu")
        )

print(model.weight)  # prints a plain Parameter, not a Params4bit

model.weight.data = torch.randint(0, 100, [16, 16])

model.cuda()  # expected to quantize the weight, but it does not

input = torch.randint(0, 16, [16]).float().cuda()

print(model(input))  # AttributeError raised here

triggers this error:

AttributeError: 'Parameter' object has no attribute 'quant_state'

because with skip_init() the statement model.cuda() does NOT go through the bitsandbytes quantization step (the Params4bit.cuda() override that quantizes the weight and attaches quant_state to it).

The error does not happen if I do not use skip_init().
I thought skip_init() was only there to avoid the uniform/normal initialization of the parameters inside nn.Linear,
but apparently it does something more.
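
You can check this directly by comparing the weight's type with and without skip_init() (the exact class paths printed may vary with your bitsandbytes version):

import bitsandbytes as bnb
import torch
import torch.nn as nn

# Normal construction: Linear4bit's constructor wraps the weight
# in the bnb.nn.Params4bit Parameter subclass.
normal = bnb.nn.Linear4bit(16, 16, bias=None, quant_type='fp4',
                           device=torch.device("cpu"))
print(type(normal.weight))   # bitsandbytes.nn.modules.Params4bit

# skip_init() construction: the weight comes back as a plain
# nn.Parameter, so Params4bit's quantization machinery is gone.
skipped = nn.utils.skip_init(bnb.nn.Linear4bit, 16, 16, bias=None,
                             quant_type='fp4', device=torch.device("cpu"))
print(type(skipped.weight))  # torch.nn.parameter.Parameter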

Here is the reason (from the skip_init() documentation):

  1. The module must not perform any computation on parameters in its constructor except initialization (i.e. functions from torch.nn.init).
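
Under the hood, skip_init() roughly does the following (a paraphrase of the PyTorch source, so details may differ between versions): it runs the constructor with device='meta', then re-materializes the parameters with to_empty(), which hands back plain nn.Parameter objects. That is exactly where the Params4bit wrapper created inside Linear4bit's constructor gets lost:

import torch
import torch.nn as nn

def skip_init_sketch(module_cls, *args, **kwargs):
    # Rough sketch of torch.nn.utils.skip_init, not the exact source.
    final_device = kwargs.pop('device', 'cpu')
    kwargs['device'] = 'meta'
    # 1. Run the constructor with all parameters on the meta device:
    #    no storage is allocated and no real initialization happens.
    module = module_cls(*args, **kwargs)
    # 2. Re-create every parameter as an empty tensor on the target
    #    device. to_empty() wraps them in plain nn.Parameter objects,
    #    so any Parameter subclass (like Params4bit) is silently dropped.
    return module.to_empty(device=final_device)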

So if we use skip_init() we need to manually redo the setup that Linear4bit's constructor normally performs, in particular re-creating the weight as a Params4bit.
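
One way to do that (a minimal sketch; I am assuming the Params4bit constructor accepts data, requires_grad and quant_type arguments, which can differ between bitsandbytes versions):

import bitsandbytes as bnb
import torch
import torch.nn as nn

model = nn.utils.skip_init(
            bnb.nn.Linear4bit,
            16,
            16,
            bias=None,
            quant_type='fp4',
            device=torch.device("cpu")
        )

# Re-wrap the weight in a Params4bit with our own init. Its cuda()
# override is what quantizes the data and attaches quant_state.
model.weight = bnb.nn.Params4bit(
    torch.randint(0, 100, [16, 16]).float(),
    requires_grad=False,
    quant_type='fp4',
)

model.cuda()  # now quantizes the weight and sets quant_state

input = torch.randint(0, 16, [16]).float().cuda()
print(model(input))  # no more AttributeError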