Hi, I am trying to use torch.distributed.all_gather function and I’m confused with the parameter ‘tensor_list’.

here’s my snippet.

```
# output: dictionary of tensors, e.g. {'key1': tensor, 'key2': tensor}
output = self.model(input)
for key in output.keys():
    tensor_list = [torch.zeros_like(output[key]) for _ in range(dist.get_world_size())]
    output[key] = self.gather_tensor(tensor_list, output[key])

def gather_tensor(self, tensor_list, tensor):
    rt = tensor.clone()
    dist.all_gather(tensor_list, rt)
    return torch.stack(tensor_list)
```

I designed ‘gather_tensor’ following the ‘reduce_tensor’ function from the apex example.

However, it raises a RuntimeError as below:

```
  in gather_tensor, dist.all_gather(tensor_list, rt)
  File "/opt/conda/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py", line 1027, in all_gather
    work = _default_pg.allgather([tensor_list], [tensor])
RuntimeError: Input tensor sequence should have the same number of tensors as the output tensor sequence
```

What’s wrong with my code?

```
output = self.model(input)
tensor_list = [torch.zeros_like(output['key1']) for _ in range(dist.get_world_size())]
for key in output.keys():
    output[key] = self.gather_tensor(tensor_list, output[key])
```

I moved the tensor_list initialization in front of the for loop as above, and the RuntimeError goes away, but then the training seems to deadlock.
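For reference, here is a minimal single-process sketch of the `all_gather` contract I am assuming (gloo backend with `world_size=1` so it runs without a launcher; the address/port values are arbitrary placeholders, not from my actual setup): `tensor_list` must contain exactly `world_size` tensors with the same shape and dtype as the input tensor.

```python
import os
import torch
import torch.distributed as dist

# Placeholder rendezvous settings for a single-process run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

t = torch.arange(4, dtype=torch.float32)
# One pre-allocated output slot per rank, matching t's shape/dtype.
tensor_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(tensor_list, t)

# Stack the gathered tensors into shape (world_size, 4).
stacked = torch.stack(tensor_list)

dist.destroy_process_group()
```

With more than one process, every rank has to reach this `all_gather` call the same number of times and in the same order, which is why I suspect the per-key looping matters for the deadlock.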