I wasn’t sure where to ask about this since this is such a minor question about FSDP internals.
FSDP2’s FSDPParam._init_sharded_param instantiates an nn.Parameter object, assigns it to self.sharded_param, and then calls the in-place setter requires_grad_ on it:
# From: https://github.com/pytorch/pytorch/blob/3e0038ae85246c1d3ffd95a618bb8fdebd9dd513/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L398-L399
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)
nn.Parameter.__new__ takes requires_grad as an argument, which defaults to True.
Suppose sharded_param has an integer dtype, for example torch.uint8 for quantization purposes. Then we first attempt to create a parameter with requires_grad=True, only to set it to False immediately afterwards. The construction itself fails, since only tensors of floating-point (or complex) dtype can require gradients.
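For reference, here is the failure mode I mean, reproduced standalone with a plain uint8 tensor (independent of FSDP; the tensor and names are just for illustration):

import torch
import torch.nn as nn

data = torch.zeros(4, dtype=torch.uint8)

# Default requires_grad=True fails for integer dtypes with something like:
# "RuntimeError: Only Tensors of floating point and complex dtype can require gradients"
try:
    nn.Parameter(data)
except RuntimeError as e:
    print(e)

# Passing requires_grad=False at construction time works fine
p = nn.Parameter(data, requires_grad=False)
print(p.dtype, p.requires_grad)  # torch.uint8 False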
Would self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param), requires_grad=param.requires_grad) be more accurate, or is there a reason the two steps have to be separate?
I’m asking because I’m experimenting with FSDP2 and QLoRA and am wondering whether this is intentional.