When creating a sharded model with DTensor, should I expect the parameters returned by model.named_parameters() to have sharded or unsharded shapes? For example, in the toy example below, where a Linear layer is sharded over 2 devices, should printing param.shape for each entry of model.named_parameters() show the sharded or the unsharded shape? Does the answer change if I remove the all-gather? I'm trying to get a sense of what to expect when inspecting sharded modules. Ideally I'd like to shard, say, a Linear layer column-wise and be able to see its sharded representation, but I'm not sure whether that's expected behaviour. Thanks!
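For context, this is roughly the check I'd like to run on each rank (just a sketch of my own; the describe_params helper is mine, and I'm importing DTensor from the same private _tensor module as in the example below):

import torch
from torch.distributed._tensor import DTensor

def describe_params(model):
    # For each parameter, print the shape reported by named_parameters()
    # and, if it is a DTensor, the shape of the local shard on this rank.
    for name, param in model.named_parameters():
        if isinstance(param, DTensor):
            print(f"{name}: param.shape={tuple(param.shape)}, "
                  f"local shard={tuple(param.to_local().shape)}, "
                  f"placements={param.placements}")
        else:
            print(f"{name}: param.shape={tuple(param.shape)} (plain tensor)")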
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
from torch.distributed import _functional_collectives as funcol

# 1-D device mesh over 2 GPUs
mesh = DeviceMesh("cuda", torch.arange(2))

class FFN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Linear(in_features=128, out_features=128)

    def forward(self, input):
        return self.w(input)

class Network(torch.nn.Module):
    def __init__(self, mesh):
        super().__init__()
        self.linear = FFN()
        # Shard the inner Linear column-wise across the mesh
        parallelize_plan = {"w": ColwiseParallel()}
        self.ffn = parallelize_module(self.linear, mesh, parallelize_plan)
        self.ffn_compiled = torch.compile(self.ffn)

    def forward(self, input):
        return self.ffn_compiled(input)

network = Network(mesh)
# All-gather the column-sharded output back to full width along the last dim
network.linear.register_forward_hook(
    lambda _module, _input, output: funcol.all_gather_tensor(output, -1, list(range(2)))
)
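For completeness, this is the driver I append to the script above (a sketch assuming two GPUs, a torchrun --nproc_per_node=2 launch, and that the default process group is set up before or by the DeviceMesh; I'm not claiming this is the canonical way to run it):

import torch.distributed as dist

# In my runs I call dist.init_process_group("nccl") and
# torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) before building the mesh.
x = torch.randn(16, 128, device="cuda")
out = network(x)
print(f"[rank {dist.get_rank()}] output shape: {tuple(out.shape)}")

# Compare param.shape against the local shard on each rank (helper defined above)
describe_params(network)

dist.destroy_process_group()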