Due to my use case, I need to perform a permute operation on an intermediate tensor in my model. With a single GPU it runs fine, but in DDP training mode it throws a stride-mismatch runtime error:
-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/home/demo.py", line 37, in train
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank], output_device=rank)
  File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
RuntimeError: params[0] in this process with sizes [3, 1, 3, 3] appears not to match strides of the same param in process 0.
Here is the minimal code to reproduce the error:
import os
import torch
import torch.nn as nn
import torch.utils.data
import torch.distributed as dist
import torch.utils.data.distributed
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'


def cleanup():
    dist.destroy_process_group()


class Model(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # depthwise (grouped) conv: one group per channel
        self.conv = nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim)

    def forward(self, x):
        # input arrives as NHWC, permute to NCHW before the conv
        out_p = self.conv(x.permute(0, 3, 1, 2))
        return out_p


def train(rank, world_size):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    net = Model(3).cuda()
    if rank == 0:
        # single forward pass on rank 0 only, before wrapping in DDP
        out = net(torch.rand((1, 32, 32, 3)).cuda())
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank], output_device=rank)
    cleanup()


def main():
    world_size = 4
    mp.spawn(train,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"
    main()
If I do not run that single inference on rank 0 before wrapping the model, no error is thrown.
I also found that if I use a regular Conv2d instead of a grouped Conv2d, it runs fine without any errors.
What confuses me is that the strides of this conv's weight apparently get changed during a single forward pass, if I understand the error correctly.
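To check that hypothesis, here is a small diagnostic sketch (my own addition, not part of the original repro; the helper name report is made up) that prints the weight's shape, strides, and whether it is channels-last contiguous, on rank 0 before and after the extra forward pass and on the other ranks just before the DDP wrap:

def report(tag, net):
    # hypothetical helper: print the conv weight's shape, strides, and memory format
    w = net.conv.weight
    layout = 'channels_last' if w.is_contiguous(memory_format=torch.channels_last) else 'contiguous'
    print(tag, tuple(w.shape), w.stride(), layout)

# inside train(), around the rank-0 forward pass:
if rank == 0:
    report('rank 0 before forward:', net)
    out = net(torch.rand((1, 32, 32, 3)).cuda())
    report('rank 0 after forward:', net)
else:
    report(f'rank {rank} before DDP wrap:', net)

If the strides reported on rank 0 after the forward differ from the other ranks, that would match what the DDP shape/stride verification is complaining about.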