I am doing tensor parallel training with DTensor. My training code uses a custom operator, gmm, which is backed by a CUDA kernel. Currently I cast the DTensor to a local tensor before the operation and cast the result back to a DTensor afterward, and this round trip has turned out to be the bottleneck in my code. Is there a way to pass the DTensor directly to my custom op and have the op interpret the tensor dimensions correctly, i.e. as the local dimensions? Doing so naively leads to a dimension bug, because the DTensor reports its global shape rather than the shape of the local shard.
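To make the dimension issue concrete, here is a minimal, self-contained sketch of the mismatch I mean (run under torchrun with 2 ranks; the shapes and the PyTorch >= 2.5 import path are assumptions for illustration):

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# e.g. torchrun --nproc_per_node=2 demo.py
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2,))
local = torch.randn(8, 16, device="cuda")  # each rank holds an 8 x 16 shard
dt = DTensor.from_local(local, device_mesh=mesh, placements=[Shard(1)])
print(dt.shape)             # torch.Size([8, 32]) -- the global view
print(dt.to_local().shape)  # torch.Size([8, 16]) -- what a kernel actually has in memory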
# Current workaround: redistribute to Replicate and materialize a plain local tensor
Ds_a_dtensor = DTensor.from_local(a_out, device_mesh=self.device_mesh, placements=[Shard(1)])
Ds_a_gathered = Ds_a_dtensor.redistribute(device_mesh=self.device_mesh, placements=[Replicate()])
a_b = Ds_a_gathered.to_local().contiguous()
b_b = torch.cat([module.weight.to_local().T.contiguous() for module in self.b], dim=0).contiguous()
batch_sizes_b = torch.tensor(self.local_rank, device='cpu', dtype=torch.long)
# Perform custom op
b_out = ops.gmm(a_b, b_b, batch_sizes_b)
# Cast back to DTensor
Ds_b_dtensor = DTensor.from_local(b_out, device_mesh=self.device_mesh, placements=[Shard(1)])
Ds_b_gathered = Ds_b_dtensor.redistribute(device_mesh=self.device_mesh, placements=[Replicate()])
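For reference, the op is exposed to PyTorch roughly like this (a simplified torch.library sketch; the real binding goes through a compiled CUDA extension, so the "my_ext" namespace, the pure-PyTorch body, and the assumed layouts -- b as (groups, k, n) and batch_sizes as a 1-D CPU tensor of per-group row counts -- are illustrative stand-ins, not my actual code):

import torch

@torch.library.custom_op("my_ext::gmm", mutates_args=())
def gmm(a: torch.Tensor, b: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the CUDA grouped GEMM: multiply each row group of
    # `a` with its corresponding weight slice in `b`.
    out = a.new_empty(a.shape[0], b.shape[-1])
    start = 0
    for g, size in enumerate(batch_sizes.tolist()):
        out[start:start + size] = a[start:start + size] @ b[g]
        start += size
    return out

@gmm.register_fake
def _(a, b, batch_sizes):
    # Shape/dtype propagation only, used by FakeTensor / torch.compile.
    return a.new_empty(a.shape[0], b.shape[-1])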