DTensor Gives Different Optimized Parameters Compared to the Undistributed Version


I am attempting to shard the input tensor of a convolution because the input is very large. In my experiment, the sharded network produces the same output as the regular, unsharded network. However, after calling .backward() and optimizer.step(), the two networks end up with different optimized parameters:

#unsharded network after optimization
Parameter containing:
tensor([[[[0.5181, 0.4689, 0.5260],
          [0.4717, 0.4164, 0.4764],
          [0.5144, 0.4654, 0.5205]]]], requires_grad=True)

#sharded network after optimization
Parameter containing:
tensor([[[[0.6927, 0.6562, 0.6900],
          [0.4922, 0.4406, 0.5017],
          [0.5139, 0.4579, 0.5114]]]], requires_grad=True)

To validate my sharding method, I create two networks that should, in theory, have the same effect (definitions are provided at the end). In the sharded version, I implemented a fix_kernel method that copies the necessary neighboring pixels into the local tensor, so that the result is the same as if the convolution were performed on the entire tensor.
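A minimal single-process sketch of the idea behind fix_kernel, assuming 4 equal row shards of a 1×1×12×10 image and a 3×3 kernel as in the post. Each chunk borrows one row from each neighbor (zeros at the image border stand in for padding) and is then convolved with padding only along the width. The helper name halo_conv is hypothetical; it simulates the ranks in a loop rather than using DTensor or collectives.

```python
import torch
import torch.nn.functional as F


def halo_conv(chunks, kernel):
    """Hypothetical helper: convolve row chunks after attaching halo rows.

    Chunk i gets the last row of chunk i-1 above it and the first row of
    chunk i+1 below it; at the image border a zero row is used instead,
    which mimics padding=1 along the height.
    """
    outputs = []
    for i, x in enumerate(chunks):
        zero_row = torch.zeros_like(x[:, :, :1, :])
        pre = chunks[i - 1][:, :, -1:, :] if i > 0 else zero_row
        post = chunks[i + 1][:, :, :1, :] if i < len(chunks) - 1 else zero_row
        x = torch.cat([pre, x, post], dim=2)
        # pad only the width: the halo rows already supply the height context
        outputs.append(F.conv2d(x, kernel, padding=(0, 1)))
    return torch.cat(outputs, dim=2)


torch.manual_seed(0)
image = torch.rand(1, 1, 12, 10)
kernel = torch.ones(1, 1, 3, 3)
chunks = list(torch.chunk(image, 4, dim=2))  # 4 simulated "ranks", 3 rows each

full = F.conv2d(image, kernel, padding=1)
sharded = halo_conv(chunks, kernel)
print(torch.allclose(full, sharded))  # True
```

This only checks the forward arithmetic, which is consistent with the observation that the sharded and unsharded outputs match.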

I then launch four processes on the CPU, all using the same random seed. In each process, I create both an unsharded network (SimpleCONV1) and a sharded network (SimpleCONV). Only SimpleCONV is wrapped in DistributedDataParallel (DDP), as shown in the following code.

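import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Shard, Replicate, distribute_tensor

# runs in each of the world_size CPU processes (default process group already initialized)
mesh = DeviceMesh("cpu", torch.arange(world_size))
net = SimpleCONV(mesh,2)
net = torch.nn.parallel.DistributedDataParallel(net)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# the unsharded copy starts from the same kernel as the sharded one
net1 = SimpleCONV1(mesh,2,net.module.kernel_1)
optimizer1 = torch.optim.SGD(net1.parameters(), lr=0.01)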

I define my input tensor as a simple random tensor and feed it into both networks. As expected, the sharded network's output tensor_c is identical to the unsharded network's output tensor_c1.

#sharded version
tensor_a = torch.rand(1, 1, 12, 10)
tensor_a = distribute_tensor(tensor_a, mesh, [Shard(dim=2)])
tensor_c = net(tensor_a) #sharded network
tensor_c = DTensor.from_local(tensor_c,mesh,[Shard(2)])

#unsharded version
tensor_a1 = tensor_a.redistribute(mesh,[Replicate()]).to_local().clone()
tensor_c1 = net1(tensor_a1) #unsharded network

However, after calling both optimizer.step() and optimizer1.step(), the two networks end up with different optimized parameters. I have already accounted for DDP's gradient averaging by multiplying the gradients of the sharded network by the world size.
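The gradient bookkeeping can be checked in a single process as well. This is a sketch, not the original code: it assumes the loss is the plain sum of the output, builds the halo with F.pad instead of fix_kernel, and replaces DDP's all-reduce with an explicit mean over the simulated ranks. Each simulated rank's kernel gradient comes only from its own output rows, so the sum over ranks should equal the unsharded gradient, and the DDP-style mean times world_size should equal that sum.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 4
image = torch.rand(1, 1, 12, 10)
rows = image.shape[2] // world_size  # 3 rows per simulated rank

# Reference: gradient of sum(conv(image)) w.r.t. the kernel, unsharded
kernel_full = torch.ones(1, 1, 3, 3, requires_grad=True)
F.conv2d(image, kernel_full, padding=1).sum().backward()

# Zero-pad the height once; shard i then sees rows [i*rows, i*rows + rows + 2),
# i.e. its own rows plus one halo row on each side
padded = F.pad(image, (0, 0, 1, 1))
local_grads = []
for i in range(world_size):
    kernel_i = torch.ones(1, 1, 3, 3, requires_grad=True)  # this rank's replica
    x_i = padded[:, :, i * rows : i * rows + rows + 2, :]
    F.conv2d(x_i, kernel_i, padding=(0, 1)).sum().backward()
    local_grads.append(kernel_i.grad)

summed = torch.stack(local_grads).sum(dim=0)
averaged = torch.stack(local_grads).mean(dim=0)  # what DDP's gradient averaging yields

print(torch.allclose(summed, kernel_full.grad))                 # True
print(torch.allclose(averaged * world_size, kernel_full.grad))  # True
```

If the distributed run disagrees with this, the difference is being introduced somewhere outside the convolution arithmetic itself.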

I am puzzled as to what might be causing this problem. Could there be an underlying issue that I’ve overlooked?

#Unsharded Version
class SimpleCONV1(torch.nn.Module):
    def __init__(self,mesh,shard_dim,kernel_1):
        super().__init__()  # required before registering a Parameter
        self.kernel_1 = torch.nn.parameter.Parameter(kernel_1.clone())

    def forward(self, x):
        # plain convolution over the whole image, zero padding of 1 on every side
        x1 = torch.nn.functional.conv2d(x, self.kernel_1,padding=1)
        return x1

#Sharded version
class SimpleCONV(torch.nn.Module):
    def __init__(self,mesh,shard_dim):
        super().__init__()
        self.kernel_1 = torch.nn.parameter.Parameter(torch.ones(1, 1, 3, 3))
        self.mesh = mesh
        self.shard_dim = shard_dim

    def forward(self, x):
        x = x.to_local()  # conv2d is applied to the plain local shard
        x = self.fix_kernel(x)
        # no padding along the sharded dim: the halo rows from fix_kernel replace it
        x = torch.nn.functional.conv2d(x, self.kernel_1,padding=self.get_padding_index())
        return x

    #I copy necessary part of the tensor to the boundary
    #to have the same effect as if the convolution is done on the whole tensor
    def fix_kernel(self,x):
        # first and last row of the local shard
        missing_part = x[:,:,[0,-1],:]
        # gather them on every rank: rows 2*r and 2*r+1 come from rank r
        missing_part = DTensor.from_local(missing_part,self.mesh,[Shard(self.shard_dim)])
        missing_part = missing_part.redistribute(device_mesh=self.mesh, placements=[Replicate()]).to_local()
        rank = torch.distributed.get_rank()
        if rank == 0:
            missing_part_after = missing_part[:,:,2:3,:]
            #pretend padding
            missing_part_pre = torch.zeros_like(missing_part_after,device=missing_part.device,dtype=missing_part.dtype)
        elif rank == torch.distributed.get_world_size() - 1:
            missing_part_pre = missing_part[:,:,-3:-2,:]
            missing_part_after = torch.zeros_like(missing_part_pre,device=missing_part.device,dtype=missing_part.dtype)
        else:
            # last row of rank - 1 and first row of rank + 1
            missing_part_pre = missing_part[:,:,(rank*2 - 1):rank*2,:]
            missing_part_after = missing_part[:, :, (rank * 2+2):(rank * 2+3), :]
        return torch.cat([missing_part_pre,x,missing_part_after],dim=self.shard_dim)

    def get_padding_index(self):
        #2D conv
        padding = [1,1]
        padding[(self.shard_dim - 2)] = 0
        return padding
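The slicing in fix_kernel follows from the layout of the gathered missing_part. A small pure-Python sketch of that index arithmetic (halo_indices is a hypothetical helper, not part of the original code):

```python
def halo_indices(rank, world_size):
    """Hypothetical helper: row indices into the gathered missing_part.

    After gathering x[:, :, [0, -1], :] from every rank along the sharded dim,
    rank r's first row sits at index 2*r and its last row at 2*r + 1.
    A rank needs the previous rank's last row (2*r - 1) and the next rank's
    first row (2*r + 2); None marks an image border that is zero-filled instead.
    """
    pre = 2 * rank - 1 if rank > 0 else None
    after = 2 * rank + 2 if rank < world_size - 1 else None
    return pre, after


for r in range(4):
    print(r, halo_indices(r, 4))
# 0 (None, 2)
# 1 (1, 4)
# 2 (3, 6)
# 3 (5, None)
```

For 4 ranks the gathered tensor has 8 rows, so rank 0's index 2 matches the 2:3 slice and the last rank's index 5 matches the -3:-2 slice.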



I’m not totally sure, and I’m tagging @wanchaol to take a closer look next week as well, but regarding this part: I’m wondering why you use .from_local on the output tensor from the sharded network. I would expect the output of the network to already be a DTensor.

Also, I’d double-check that torch.sum(tensor_c) does what you expect: does it actually reduce across shards, or does it only compute a local sum? Can you confirm that the sum of tensor_c from your sharded net matches the sum of tensor_c1 from your unsharded net before starting backward?

The output from the network is a local tensor. At the very start of my network, I convert the sharded tensor to a local tensor, because DTensor is not supported by torch.nn.functional.conv2d. In my network, each rank takes a small piece of the image, and I manually attach the surrounding pixels to that piece, so that overall the convolution behaves as if it were applied to the whole image.

I can confirm that the two sums are equal, as I explicitly did an element-wise comparison.