Certain inference batch sizes raise ValueError with DataParallel

Issue description

During inference with multiple GPUs, an exception is always raised for specific batch sizes, such as [9, 11, 13, 15, 19, 22]. The same model runs normally on a single device.
Downgrading to the stable torch 0.4.0 release does not fix the problem.

Code example

import sys
import traceback

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 10, kernel_size=3, padding=1, stride=2)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=1, padding=0, stride=2)
        self.fc = torch.nn.Linear(20, 2)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).squeeze()
        x = self.fc(x)
        return x


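problem_pair = []  # [number of GPUs, batch size] pairs that raise
exception = None   # exc_info of the most recent failure

for i in range(8):
    model = Net()
    if i != 0:
        # Wrap the model in DataParallel over the first i GPUs.
        model = torch.nn.DataParallel(model, device_ids=list(range(i))).cuda()
    else:
        # i == 0: plain single-GPU model, no DataParallel.
        model.cuda()
    for j in range(1, 60):
        try:
            data = torch.rand(j, 3, 8, 8).cuda()
            _ = model(data)
        except Exception:
            exception = sys.exc_info()
            problem_pair.append([i, j])

print('problem pair {} raise error'.format(problem_pair))
# Print the traceback of the last failure, if there was one.
if exception is not None:
    traceback.print_exception(*exception)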

Output

problem pair [[2, 3], [3, 5], [3, 7], [4, 5], [4, 7], [4, 10], [4, 13], [5, 7], [5, 9], [5, 13], [5, 17], [5, 21], [6, 7], [6, 9], [6, 11], [6, 13], [6, 16], [6, 21], [6, 26], [6, 31], [7, 9], [7, 11], [7, 13], [7, 16], [7, 19], [7, 25], [7, 31], [7, 37], [7, 43]] raise error
Traceback (most recent call last):
  File "demo.py", line 33, in <module>
    _ = model(data)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 115, in forward
    return self.gather(outputs, self.output_device)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 127, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    return gather_map(outputs)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/nn/parallel/_functions.py", line 55, in forward
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/home/wangyulong/.local/lib/python3.5/site-packages/torch/cuda/comm.py", line 186, in gather
    "but expected {}".format(got, expected))
ValueError: gather got an input of invalid size: got 2, but expected 2x2
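
The failing pairs above seem to follow a pattern that can be checked without any GPU. The sketch below is pure Python and rests on an assumption: that DataParallel splits the batch the way `torch.chunk` does, giving each GPU `ceil(batch / n_gpus)` samples and the last GPU whatever remains. The helper name `failing_pairs` is hypothetical and is not part of PyTorch.

```python
import math


def failing_pairs(max_gpus=7, max_batch=59):
    """List [n_gpus, batch] pairs where the last chunk holds exactly one
    sample while the other chunks hold more than one.

    Assumes chunk-style splitting: every chunk has ceil(batch / n_gpus)
    samples except possibly the last one.
    """
    pairs = []
    for n_gpus in range(2, max_gpus + 1):
        for batch in range(1, max_batch + 1):
            chunk = math.ceil(batch / n_gpus)      # samples per full chunk
            n_chunks = math.ceil(batch / chunk)    # chunks actually produced
            last = batch - chunk * (n_chunks - 1)  # samples in the last chunk
            if chunk > 1 and last == 1:
                pairs.append([n_gpus, batch])
    return pairs


print(failing_pairs())
```

Under that assumption the function returns the same 29 pairs as the run above, which suggests the error appears exactly when one replica receives a single sample and the others receive more.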

System Info

PyTorch version: 0.5.0a0+6eec411
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609
CMake version: version 3.11.0

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti
GPU 4: GeForce GTX 1080 Ti
GPU 5: GeForce GTX 1080 Ti
GPU 6: GeForce GTX 1080 Ti
GPU 7: GeForce GTX 1080 Ti

Nvidia driver version: 390.30
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy (1.14.3)
[pip3] torch (0.5.0a0+6eec411)
[pip3] torchvision (0.2.1)
[conda] Could not collect

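OK, this was my own mistake. Whenever one GPU ends up with a single sample (for example, when batch size % number of GPUs == 1), squeeze() also removes the batch dimension, so that replica's output turns from a 2-D array into a 1-D array and gather fails.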
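
A minimal sketch of the shape problem and one possible workaround, runnable on CPU. The helper names `pooled_squeeze` and `pooled_view` are made up for illustration. Replacing `squeeze()` with `view(x.size(0), -1)` is just one option that keeps the batch dimension; it is not necessarily the only or the best fix.

```python
import torch
import torch.nn.functional as F


def pooled_squeeze(x):
    # Pattern from the example above: squeeze() drops every size-1
    # dimension, including the batch dimension when it is 1.
    return F.adaptive_avg_pool2d(x, 1).squeeze()


def pooled_view(x):
    # Possible workaround: reshape explicitly so dim 0 stays the batch.
    return F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)


one = torch.rand(1, 20, 2, 2)  # a replica that received a single sample
two = torch.rand(2, 20, 2, 2)  # a replica that received two samples

print(pooled_squeeze(one).shape, pooled_squeeze(two).shape)
print(pooled_view(one).shape, pooled_view(two).shape)
```

With `squeeze()` the single-sample input comes out 1-D while the two-sample input stays 2-D, which is the mismatch gather complains about. With `view` both outputs stay 2-D.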