Very small, stupid question about FSDPParam._init_sharded_param

I wasn’t sure where to ask this, since it’s such a minor question about FSDP internals.

In FSDP2’s FSDPParam._init_sharded_param, the function instantiates an nn.Parameter, assigns it to self.sharded_param, and then calls the in-place setter requires_grad_ on it:

# From: https://github.com/pytorch/pytorch/blob/3e0038ae85246c1d3ffd95a618bb8fdebd9dd513/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L398-L399
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)

nn.Parameter.__new__ takes requires_grad as an argument, which defaults to True.

Suppose sharded_param has an integer dtype, for example torch.uint8 for quantization purposes. Then we first create a parameter with requires_grad=True, only to set it to False immediately afterwards. The first step fails, since only floating-point (and complex) dtype tensors can require gradients.
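
Here is a minimal reproduction of the failure, using a plain uint8 tensor as a stand-in for the sharded DTensor (the dtype check is the same either way):

import torch
import torch.nn as nn

# Stand-in for the sharded quantized tensor; the real code wraps a DTensor.
sharded = torch.zeros(4, dtype=torch.uint8)

try:
    # nn.Parameter defaults to requires_grad=True, so this raises:
    # RuntimeError: Only Tensors of floating point and complex dtype can require gradients
    p = nn.Parameter(sharded)
except RuntimeError as e:
    print(e)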

Would self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param), requires_grad=param.requires_grad) be more accurate, or is there a reason this has to be done in two steps?
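
For comparison, here is a sketch of what passing requires_grad at construction would do, again with a plain tensor standing in for the DTensor; because the parameter never goes through a requires_grad=True state, the integer dtype is accepted:

import torch
import torch.nn as nn

sharded = torch.zeros(4, dtype=torch.uint8)

# requires_grad is set once at construction, so no intermediate
# requires_grad=True state is created for the integer-dtype tensor.
p = nn.Parameter(sharded, requires_grad=False)
print(p.dtype, p.requires_grad)  # torch.uint8 False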

I’m asking since I’m experimenting with FSDP2 and QLoRA and wondering whether this is intentional.