I don’t believe your use case is easily doable, since autograd calculates gradients for the entire parameter and not for subtensors of it.
You could check this approach, which uses subtensors with a frozen and a trainable part and might fit your use case. I’ve also posted a way to replace built-in layers with this custom module using torch.fx, in case that’s helpful.
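
For reference, here is a minimal sketch of what the frozen/trainable split could look like for a linear layer. It is not the exact code from the linked post: the class name `PartiallyFrozenLinear`, the `n_frozen` argument, and the choice to freeze the first output rows are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartiallyFrozenLinear(nn.Module):
    """Linear layer whose first `n_frozen` weight rows stay fixed (hypothetical helper)."""

    def __init__(self, in_features, out_features, n_frozen):
        super().__init__()
        w = torch.randn(out_features, in_features) * 0.1
        # Frozen rows live in a buffer: no gradient, not returned by .parameters().
        self.register_buffer("weight_frozen", w[:n_frozen].clone())
        # Remaining rows are a regular trainable parameter.
        self.weight_trainable = nn.Parameter(w[n_frozen:].clone())
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # Rebuild the full weight in every forward pass from both parts.
        weight = torch.cat([self.weight_frozen, self.weight_trainable], dim=0)
        return F.linear(x, weight, self.bias)


torch.manual_seed(0)
layer = PartiallyFrozenLinear(in_features=4, out_features=3, n_frozen=1)
frozen_before = layer.weight_frozen.clone()
trainable_before = layer.weight_trainable.detach().clone()

optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.randn(8, 4)).pow(2).sum()
loss.backward()
optimizer.step()

# Only the trainable rows are updated; the frozen rows keep their values.
print(torch.equal(layer.weight_frozen, frozen_before))
print(torch.equal(layer.weight_trainable, trainable_before))
```

Since the frozen part is a buffer, it is still saved in the `state_dict` and moved by `.to(device)`, but the optimizer never sees it. The bias is kept fully trainable here just to keep the sketch short; it could be split the same way.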