What's the recommended way to integrate FSDP with a custom tensor unit?

I'm trying to wrap a model that contains custom tensor types with FSDP, but I'm hitting errors. I'll first show the specific errors in case someone else runs into them.
Reference repo:

First, I wrapped this SparseLinear as an FSDP unit using FULL_SHARD, which produced this error:

```
RuntimeError: setStorage: sizes [1024, 4096], strides [4096, 1], storage offset 12597248, and itemsize 2 requiring a storage sizeof 33583104 are out of bounds for storage of size 0
```
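For context, the wrapping looked roughly like the following sketch. It assumes a process group is already initialized via `torch.distributed.init_process_group`; the `SparseLinear` here is a stub standing in for my real module, whose forward returns a custom tensor type rather than a plain `torch.Tensor`.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

class SparseLinear(nn.Module):          # stub for the real module
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

    def forward(self, x):
        # the real forward returns a custom tensor type, not a torch.Tensor
        return x @ self.weight.t()

# wrap the custom module as its own FSDP unit with FULL_SHARD
sparse = SparseLinear(4096, 1024).cuda()
sparse = FSDP(sparse, sharding_strategy=ShardingStrategy.FULL_SHARD)
```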

Then I switched FULL_SHARD to SHARD_GRAD_OP, and another error showed up:

```
ValueError: expected to be in states [<TrainingState.FORWARD_BACKWARD: 2>] but current state is TrainingState.IDLE
```

Eventually, it turned out that when FSDP registers hooks on an FSDP unit's output tensors, any output that is not a torch.Tensor is skipped, so the pre-backward hook is never registered.
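The sketch below illustrates the behavior as I understand it; it is a simplified stand-in for FSDP's actual output traversal (the real function names and logic differ), showing how a non-Tensor wrapper falls through the `isinstance` check:

```python
import torch

class SparseTensor:            # a plain wrapper, NOT a torch.Tensor subclass
    def __init__(self, data):
        self.data = data

def register_pre_backward_hooks(output):
    # simplified version of FSDP's traversal over forward outputs
    if isinstance(output, torch.Tensor) and output.requires_grad:
        output.register_hook(lambda grad: print("pre-backward hook fired"))
    elif isinstance(output, (list, tuple)):
        for item in output:
            register_pre_backward_hooks(item)
    # anything else (e.g. SparseTensor) falls through silently

t = torch.randn(2, 2, requires_grad=True)
register_pre_backward_hooks(SparseTensor(t))   # hook never attached
register_pre_backward_hooks(t)                 # hook attached
(t * 2).sum().backward()                       # prints once
```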

For now, I just make sure that SparseLinear outputs a plain tensor and do the conversion elsewhere, but I'm looking for better advice here!
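The workaround looks roughly like this sketch: only a plain `torch.Tensor` crosses the FSDP unit boundary, and the custom format is rebuilt afterwards. `to_custom` is a hypothetical name for my conversion routine, and `dense.to_sparse()` is just a stand-in for the real conversion.

```python
import torch
import torch.nn as nn

class SparseLinearDenseOut(nn.Module):
    """Variant of SparseLinear that returns a plain torch.Tensor."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.t()      # plain tensor: FSDP hooks attach fine

def to_custom(dense):                   # hypothetical conversion helper
    return dense.to_sparse()            # stand-in for the real format

# usage: wrap only the dense-output module with FSDP, convert afterwards
layer = SparseLinearDenseOut(8, 4)
out = to_custom(layer(torch.randn(2, 8)))
```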