Size mismatch error in distributed training

Hi, I have a problem starting distributed training with PyTorch.

The following sample code reproduces the issue:

from __future__ import print_function
import os
import torch
import torch.distributed
import torch.nn as nn
import torch.utils.data as data
import torch.utils.data.distributed


if __name__ == "__main__":
  # Initialize the model
  model = nn.Sequential(
      nn.Linear(20, 10),
      nn.Linear(10, 20)
  )

  # Initialize the process group: 2 processes, file-based rendezvous, gloo backend
  torch.distributed.init_process_group(
      world_size=2,
      init_method='file:///' + os.path.join(os.environ['HOME'], 'distributedFile'),
      backend='gloo')
  # Move the model to GPU and wrap it with DistributedDataParallel
  model.cuda()
  model = nn.parallel.DistributedDataParallel(model)
  print(4)

  for epoch in range(5):
    total_loss = 0
    for idx in range(5):
      # one-hot input vector of length 20
      in_val = torch.zeros(20)
      in_val[idx] = 1.0
      print(model)
      output = model(in_val)

The error message is:

Traceback (most recent call last):
  File "dist_test.py", line 36, in <module>
    output = model(in_val)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 216, in forward
    outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 223, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 65, in parallel_apply
    raise output
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 41, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/functional.py", line 994, in linear
    output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [1 x 10], m2: [20 x 10] at /opt/conda/conda-bld/pytorch_1524586445097/work/aten/src/THC/generic/THCTensorMathBlas.cu:249
terminate called after throwing an instance of 'gloo::EnforceNotMet'
  what():  [enforce fail at /opt/conda/conda-bld/pytorch_1524586445097/work/third_party/gloo/gloo/cuda.cu:249] error == cudaSuccess. 29 vs 0. Error at: /opt/conda/conda-bld/pytorch_1524586445097/work/third_party/gloo/gloo/cuda.cu:249: driver shutting down
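
If I read the shapes in the error correctly, the first Linear(20, 10) layer is receiving a [1 x 10] input instead of the [1 x 20] it expects, as if the size-20 input vector were being split in half across the two devices. The snippet below is only my own sanity check outside of DistributedDataParallel (not part of the repro above); it shows that feeding a 10-feature input to the same kind of layer produces the same shape mismatch:

import torch
import torch.nn as nn

layer = nn.Linear(20, 10)          # weight has shape [10, 20]
out = layer(torch.zeros(1, 20))    # fine: [1 x 20] matmul [20 x 10] -> [1 x 10]
try:
    layer(torch.zeros(1, 10))      # [1 x 10] vs [20 x 10]: the mismatch from the traceback
except RuntimeError as e:
    print(e)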

Any feedback or suggestions would be appreciated.