How to avoid casting DTensor to Tensor before calling a custom operator (backed by a CUDA kernel)

I am doing tensor-parallel training with DTensor. My training code uses a custom operator, gmm, which is backed by a CUDA kernel. Currently I cast the DTensor to a local Tensor before the op and cast the result back to a DTensor afterward, and this round trip turns out to be the bottleneck in my code. Is there a way to pass the DTensor directly to my custom op and have the op see the correct (local) shard shape? Doing so naively leads to a shape bug, because the DTensor reports its global shape rather than the shape of the local shard.
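
For reference, the op is exposed to PyTorch roughly like this (a sketch with illustrative names; mylib::gmm and the ops extension binding are placeholders rather than my exact registration code):

    import torch

    # Hypothetical registration of the CUDA kernel as a PyTorch custom op so
    # that it goes through the dispatcher; all names here are placeholders.
    @torch.library.custom_op("mylib::gmm", mutates_args=())
    def gmm(a: torch.Tensor, b: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor:
        # call into the compiled CUDA extension ("ops" is the extension module
        # used in the snippet below)
        return ops.gmm(a, b, batch_sizes)

    @gmm.register_fake
    def _(a, b, batch_sizes):
        # shape propagation for tracing; the 2D output shape is an assumption
        return a.new_empty(a.shape[0], b.shape[-1])

Here is the current workaround in my forward pass, which gathers everything into plain local tensors around the call: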

        # Wrap the local shard in a DTensor, all-gather it via redistribute,
        # and materialize the full activation as a plain local tensor
        Ds_a_dtensor = DTensor.from_local(a_out, device_mesh=self.device_mesh, placements=[Shard(1)])
        Ds_a_gathered = Ds_a_dtensor.redistribute(device_mesh=self.device_mesh, placements=[Replicate()])
        a_b = Ds_a_gathered.to_local().contiguous()
        # Concatenate the local weight shards into one plain tensor
        b_b = torch.cat([module.weight.to_local().T.contiguous() for module in self.b], dim=0).contiguous()
        batch_sizes_b = torch.tensor(self.local_rank, device='cpu', dtype=torch.long)
        
        # Run the custom op on the plain local tensors
        b_out = ops.gmm(a_b, b_b, batch_sizes_b)

        # Wrap the result back into a DTensor and re-replicate it
        Ds_b_dtensor = DTensor.from_local(b_out, device_mesh=self.device_mesh, placements=[Shard(1)])
        Ds_b_gathered = Ds_b_dtensor.redistribute(device_mesh=self.device_mesh, placements=[Replicate()])
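
What I am hoping for is a way to tell DTensor how gmm behaves under sharding, so the kernel can run directly on the local shards without the redistribute/to_local round trip. Something like the sketch below, using the experimental torch.distributed.tensor.experimental.register_sharding API (I am not sure it is intended for custom ops; the placements are illustrative and mylib::gmm refers to the hypothetical registration above):

    import torch
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.experimental import register_sharding

    # Hypothetical sharding rule: each returned entry is a tuple
    # (output_specs, input_specs) giving one acceptable combination of
    # placements for the output and for the inputs (a, b, batch_sizes).
    @register_sharding(torch.ops.mylib.gmm.default)
    def gmm_sharding_strategy(a, b, batch_sizes):
        # Fallback strategy: everything replicated. None is used for
        # batch_sizes, which stays a plain CPU tensor rather than a DTensor
        # (not sure this is the right way to mark it).
        all_replicate = ([Replicate()], [Replicate(), Replicate(), None])
        # A real rule would add entries where a and b stay sharded and the
        # kernel runs on local shards, but I do not know how to express the
        # output placement for my Shard(1) layout correctly.
        return [all_replicate]

Would something along these lines let DTensor dispatch gmm on the local shards, or is there a recommended way to handle custom ops that take DTensor inputs?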