Hi,
I am trying to shard the input tensor of a convolution because the input is very large. In my experiment, the sharded network produces the same output as the regular, unsharded network. However, after I call `.backward()` and `optimizer.step()`, the two networks end up with different parameters:
```
# unsharded network after optimization
Parameter containing:
tensor([[[[0.5181, 0.4689, 0.5260],
          [0.4717, 0.4164, 0.4764],
          [0.5144, 0.4654, 0.5205]]]], requires_grad=True)

# sharded network after optimization
Parameter containing:
tensor([[[[0.6927, 0.6562, 0.6900],
          [0.4922, 0.4406, 0.5017],
          [0.5139, 0.4579, 0.5114]]]], requires_grad=True)
```
To validate my sharding method, I created two networks that should, in theory, behave identically (definitions are at the end). In the sharded version, I implemented a `fix_kernel` method that copies the necessary neighboring rows from the adjacent shards into the local tensor, so that the convolution has the same effect as if it were performed on the entire tensor.
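The halo logic can be checked without any distributed setup. Below is a minimal single-process sketch, assuming the 1x1x12x10 input from this post split into 4 equal shards along dim 2 and a 3x3 kernel. `simulate_fix_kernel` is a hypothetical helper that mirrors the indexing in `fix_kernel`: it stacks the first and last row of every shard (what all-gathering `x[:, :, [0, -1], :]` yields), takes the previous shard's last row at index `2*rank - 1` and the next shard's first row at index `2*rank + 2`, and uses zero rows at the global edges.

```python
import torch
import torch.nn.functional as F


def simulate_fix_kernel(shards):
    """Hypothetical single-process stand-in for fix_kernel: return each shard
    with one neighbor row prepended and appended (zeros at the global edges)."""
    world_size = len(shards)
    # Equivalent of all-gathering x[:, :, [0, -1], :] from every rank;
    # rows are ordered [r0 first, r0 last, r1 first, r1 last, ...]
    boundary = torch.cat([s[:, :, [0, -1], :] for s in shards], dim=2)
    padded = []
    for rank, x in enumerate(shards):
        zero_row = torch.zeros_like(x[:, :, :1, :])
        # last row of the previous shard, or zero padding on rank 0
        pre = boundary[:, :, 2 * rank - 1:2 * rank, :] if rank > 0 else zero_row
        # first row of the next shard, or zero padding on the last rank
        after = (boundary[:, :, 2 * rank + 2:2 * rank + 3, :]
                 if rank < world_size - 1 else zero_row)
        padded.append(torch.cat([pre, x, after], dim=2))
    return padded


torch.manual_seed(0)
full = torch.rand(1, 1, 12, 10)
kernel = torch.rand(1, 1, 3, 3)
shards = list(torch.chunk(full, 4, dim=2))

reference = F.conv2d(full, kernel, padding=1)
# H is already extended by the halo rows, so only W is padded here
sharded = torch.cat(
    [F.conv2d(x, kernel, padding=(0, 1)) for x in simulate_fix_kernel(shards)],
    dim=2,
)
print(torch.allclose(reference, sharded, atol=1e-6))  # True
```

This only exercises the index arithmetic and the padding choice; the real `fix_kernel` additionally goes through `DTensor.redistribute`, which this sketch does not simulate.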
Next, I launch four processes on the CPU, all using the same random seed. In each process, I create both an unsharded network (`SimpleCONV1`) and a sharded network (`SimpleCONV`). The `SimpleCONV` network is wrapped in DDP, as shown in the following code:
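```python
import torch
# newer PyTorch releases expose these under torch.distributed.tensor
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, distribute_tensor

# runs inside each of the four CPU processes (process group already initialized)
mesh = DeviceMesh("cpu", torch.arange(world_size))
net = SimpleCONV(mesh, 2)
net = torch.nn.parallel.DistributedDataParallel(net)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
optimizer.zero_grad()

# unsharded reference, initialized from the same kernel
net1 = SimpleCONV1(mesh, 2, net.module.kernel_1)
optimizer1 = torch.optim.SGD(net1.parameters(), lr=0.01)
optimizer1.zero_grad()
```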
I define the input as a simple random tensor and feed it into both networks. As expected, the sharded network's output `tensor_c` is identical to the unsharded network's output `tensor_c1`.
```python
# sharded version
tensor_a = torch.rand(1, 1, 12, 10)
tensor_a = distribute_tensor(tensor_a, mesh, [Shard(dim=2)])
tensor_c = net(tensor_a)  # sharded network, returns the local output shard
tensor_c = DTensor.from_local(tensor_c, mesh, [Shard(2)])
torch.sum(tensor_c).backward()

# unsharded version
tensor_a1 = tensor_a.redistribute(mesh, [Replicate()]).to_local().clone()
tensor_c1 = net1(tensor_a1)  # unsharded network
torch.sum(tensor_c1).backward()
```
However, after calling both `optimizer.step()` and `optimizer1.step()`, the two networks end up with different parameters. I have already accounted for DDP's gradient averaging by multiplying the gradient of the sharded network by the world size.
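The scaling argument itself can be checked in a single process. With the halo rows in place, each shard computes exactly its own slice of output rows, and the kernel gradient of `sum(output)` is a sum over output positions, so the per-shard kernel gradients should add up to the full-tensor gradient. DDP averages gradients across ranks, so the averaged gradient times `world_size` should match that sum. The sketch below assumes the shapes from this post and builds the zero-padded halo by slicing a padded copy of the full input (an equivalent shortcut, not the `fix_kernel` code path); it computes each shard's gradient separately, averages them the way DDP's all-reduce would, and compares against the unsharded gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, rows_per_shard = 4, 3
full = torch.rand(1, 1, world_size * rows_per_shard, 10)
init = torch.rand(1, 1, 3, 3)

# Unsharded reference: gradient of sum(conv2d(full)) w.r.t. the kernel
kernel = init.clone().requires_grad_(True)
F.conv2d(full, kernel, padding=1).sum().backward()
full_grad = kernel.grad.clone()

# Zero-pad H once; shard r then sees its own rows plus one halo row on each side
padded = F.pad(full, (0, 0, 1, 1))
local_grads = []
for rank in range(world_size):
    start = rank * rows_per_shard
    x = padded[:, :, start:start + rows_per_shard + 2, :]
    k = init.clone().requires_grad_(True)  # every rank holds an identical replica
    F.conv2d(x, k, padding=(0, 1)).sum().backward()
    local_grads.append(k.grad)

# DDP all-reduces and divides by world_size, i.e. it averages
ddp_grad = torch.stack(local_grads).mean(dim=0)
print(torch.allclose(ddp_grad * world_size, full_grad, atol=1e-4))  # True
```

A passing check only confirms the arithmetic; it does not exercise DDP's hooks or the DTensor collectives used in the real model.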
I am puzzled about what might be causing this. Is there an underlying issue that I have overlooked?
```python
# Unsharded version
class SimpleCONV1(torch.nn.Module):
    def __init__(self, mesh, shard_dim, kernel_1):
        super().__init__()
        # copy the kernel so both networks start from identical weights
        self.kernel_1 = torch.nn.parameter.Parameter(kernel_1.clone())

    def forward(self, x):
        x1 = torch.nn.functional.conv2d(x, self.kernel_1, padding=1)
        return x1


# Sharded version
class SimpleCONV(torch.nn.Module):
    def __init__(self, mesh, shard_dim):
        super().__init__()
        self.kernel_1 = torch.nn.parameter.Parameter(torch.ones(1, 1, 3, 3))
        self.mesh = mesh
        self.shard_dim = shard_dim
        # no padding along the sharded dim; fix_kernel supplies those rows
        self.padding = self.get_padding_index()

    def forward(self, x):
        x = x.to_local()
        x = self.fix_kernel(x)
        x = torch.nn.functional.conv2d(x, self.kernel_1, padding=self.padding)
        return x

    # I copy the necessary part of the tensor to the boundary
    # to have the same effect as if the convolution were done on the whole tensor
    def fix_kernel(self, x):
        # first and last row of every shard, gathered in rank order:
        # [r0 first, r0 last, r1 first, r1 last, ...]
        missing_part = x[:, :, [0, -1], :]
        missing_part = DTensor.from_local(missing_part, self.mesh, [Shard(self.shard_dim)])
        missing_part = missing_part.redistribute(device_mesh=self.mesh, placements=[Replicate()]).to_local()
        rank = torch.distributed.get_rank()
        if rank == 0:
            missing_part_after = missing_part[:, :, 2:3, :]
            # pretend padding
            missing_part_pre = torch.zeros_like(missing_part_after, device=missing_part.device, dtype=missing_part.dtype)
        elif rank == torch.distributed.get_world_size() - 1:
            missing_part_pre = missing_part[:, :, -3:-2, :]
            missing_part_after = torch.zeros_like(missing_part_pre, device=missing_part.device, dtype=missing_part.dtype)
        else:
            missing_part_pre = missing_part[:, :, (rank * 2 - 1):rank * 2, :]
            missing_part_after = missing_part[:, :, (rank * 2 + 2):(rank * 2 + 3), :]
        return torch.cat([missing_part_pre, x, missing_part_after], dim=self.shard_dim)

    def get_padding_index(self):
        # 2D conv: [pad_H, pad_W]
        padding = [1, 1]
        padding[(self.shard_dim - 2)] = 0
        return padding
```
Thanks