Performance gets worse when using multiple GPUs

import torch
import torch.nn as nn
import time

ITER_NUM = 100
BATCH_SIZE = 1000
CHANNEL_NUM = 128
KERNEL_SIZE = 3
INPUT_LENGTH = 1000
CPU_NUM = torch.get_num_threads()
GPU_NUM = torch.cuda.device_count()

class Net(nn.Module):
    """Depthwise conv only (groups == number of channels)."""
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Conv1d(CHANNEL_NUM, CHANNEL_NUM, KERNEL_SIZE, groups=CHANNEL_NUM)
    def forward(self, x):
        return self.net(x)

class Net1(nn.Module):
    """Depthwise-separable conv: depthwise conv followed by a 1x1 pointwise conv."""
    def __init__(self):
        super(Net1, self).__init__()
        self.net = nn.Sequential(
                nn.Conv1d(CHANNEL_NUM, CHANNEL_NUM, KERNEL_SIZE, groups=CHANNEL_NUM),
                nn.Conv1d(CHANNEL_NUM, CHANNEL_NUM, 1)
                )
    def forward(self, x):
        return self.net(x)

class Net2(nn.Module):
    """Plain ('naive') conv."""
    def __init__(self):
        super(Net2, self).__init__()
        self.net = nn.Conv1d(CHANNEL_NUM, CHANNEL_NUM, KERNEL_SIZE)
    def forward(self, x):
        return self.net(x)

def test(net, x, desc, iter_num=ITER_NUM):
    # make sure no earlier kernels are still running before starting the timer
    torch.cuda.synchronize()
    ts = time.time()
    # inference
    with torch.no_grad():
        for i in range(iter_num):
            y = net(x)
    # training (comment out the no_grad block above and uncomment this instead)
    # for i in range(iter_num):
    #     y = net(x)
    #     y.mean().backward()
    torch.cuda.synchronize()
    print(desc, 'total time:', (time.time() - ts), 'for', iter_num, 'iterations')

# single GPU
net = Net().cuda()
net1 = Net1().cuda()
net2 = Net2().cuda()
x = torch.randn([BATCH_SIZE, CHANNEL_NUM, INPUT_LENGTH]).cuda()

print('====== start test! gpu num: 1', '======')

test(net, x, '1 gpu depthwise conv')
test(net1, x, '1 gpu dept+pointwise conv')
test(net2, x, '1 gpu naive conv')

# multiple GPUs with nn.DataParallel
net = nn.DataParallel(Net()).cuda()
net1 = nn.DataParallel(Net1()).cuda()
net2 = nn.DataParallel(Net2()).cuda()
x = torch.randn([BATCH_SIZE*GPU_NUM, CHANNEL_NUM, INPUT_LENGTH]).cuda()

print('====== start test! gpu num:', GPU_NUM, '======')

test(net, x, '{} gpu depthwise conv'.format(GPU_NUM), ITER_NUM//GPU_NUM)
test(net1, x, '{} gpu dept+pointwise conv'.format(GPU_NUM), ITER_NUM//GPU_NUM)
test(net2, x, '{} gpu naive conv'.format(GPU_NUM), ITER_NUM//GPU_NUM)

I was trying to compare the performance of a normal CNN and a depthwise-separable CNN on one GPU versus multiple GPUs. However, I found that all the models I tested perform worse on multiple GPUs, even though I use a larger batch size with fewer iterations. The degradation is especially dramatic for the first one (Net), which has only a single depthwise layer (groups=input_channels).
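
For reference, a rough back-of-the-envelope estimate of the per-sample work done by the three nets (multiply-accumulate counts only, ignoring memory traffic):

# Conv1d MACs ~= out_channels * out_length * (in_channels / groups) * kernel_size
C, K, L = 128, 3, 1000                     # CHANNEL_NUM, KERNEL_SIZE, input length
out_len = L - K + 1                        # 998 (no padding)

depthwise = C * out_len * 1 * K            # Net:  ~0.38 M MACs per sample
separable = depthwise + C * out_len * C    # Net1: ~16.7 M MACs per sample
naive = C * out_len * C * K                # Net2: ~49.1 M MACs per sample

So the depthwise-only net does by far the least arithmetic per sample (about 128x less than the plain conv).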

This is the result for inference:

====== start test! gpu num: 1 ======
1 gpu depthwise conv total time: 0.951937198638916 for 100 iterations
1 gpu dept+pointwise conv total time: 4.988993883132935 for 100 iterations
1 gpu naive conv total time: 4.091394424438477 for 100 iterations
====== start test! gpu num: 4 ======
4 gpu depthwise conv total time: 22.152369737625122 for 25 iterations
4 gpu dept+pointwise conv total time: 10.338437795639038 for 25 iterations
4 gpu naive conv total time: 10.148297309875488 for 25 iterations

This is the result for training (using the commented-out part in the test function):

====== start test! gpu num: 1 ======
1 gpu depthwise conv total time: 2.49479079246521 for 100 iterations
1 gpu dept+pointwise conv total time: 9.238693952560425 for 100 iterations
1 gpu naive conv total time: 8.14995265007019 for 100 iterations
====== start test! gpu num: 4 ======
4 gpu depthwise conv total time: 27.595251321792603 for 25 iterations
4 gpu dept+pointwise conv total time: 15.855156183242798 for 25 iterations
4 gpu naive conv total time: 15.422184944152832 for 25 iterations

What is the reason?

Since your workload is tiny, you might be seeing the communication overhead (the scatter and gather across all GPUs).
Could you check the performance for larger models?
Also, have a look at DistributedDataParallel, which might perform better than nn.DataParallel.
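
Something along these lines should work as a starting point (just a rough sketch: it assumes one process per GPU launched with torchrun, and the script name in the comment is made up, so adapt it to your setup):

import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn

# launch with e.g.: torchrun --nproc_per_node=4 ddp_bench.py
def main():
    # torchrun sets LOCAL_RANK (and MASTER_ADDR/MASTER_PORT) for each process
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)

    net = nn.Conv1d(128, 128, 3, groups=128).cuda()
    net = nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])

    # every process holds its own shard, so keep the per-GPU batch size
    x = torch.randn(1000, 128, 1000, device=local_rank)

    torch.cuda.synchronize()
    ts = time.time()
    for _ in range(25):
        net.zero_grad()
        y = net(x)
        # DDP overlaps the gradient all-reduce with the backward pass
        y.mean().backward()
    torch.cuda.synchronize()
    if local_rank == 0:
        print('DDP depthwise conv total time:', time.time() - ts, 'for 25 iterations')

    dist.destroy_process_group()

if __name__ == '__main__':
    main()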

Thank you for your answer! I tried DistributedDataParallel and the performance finally makes sense. At least the simplest model (depthwise) now trains and runs inference faster on multiple GPUs than the other two (depthwise+pointwise and naive). Here is the result for training:

====== start test! gpu num: 1 ======
1 gpu depthwise conv total time: 2.508331537246704 for 100 iterations
1 gpu dept+pointwise conv total time: 9.373668670654297 for 100 iterations
1 gpu naive conv total time: 8.2487051486969 for 100 iterations
====== start test! gpu num: 4 ======
4 gpu depthwise conv total time: 12.313549995422363 for 25 iterations
4 gpu dept+pointwise conv total time: 13.747057676315308 for 25 iterations
4 gpu naive conv total time: 13.246102333068848 for 25 iterations

Here is the result for inference:

====== start test! gpu num: 1 ======
1 gpu depthwise conv total time: 0.9100768566131592 for 100 iterations
1 gpu dept+pointwise conv total time: 4.834381341934204 for 100 iterations
1 gpu naive conv total time: 4.099367618560791 for 100 iterations
====== start test! gpu num: 4 ======
4 gpu depthwise conv total time: 7.846449375152588 for 25 iterations
4 gpu dept+pointwise conv total time: 8.739933252334595 for 25 iterations
4 gpu naive conv total time: 8.57079792022705 for 25 iterations

As for the communication overhead, I will embed these models in a larger model and see what happens.
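
For example, something roughly like this (just a hypothetical stack of the depthwise-separable block above, not the actual model I will use):

class BigNet(nn.Module):
    """A deeper stack of depthwise-separable blocks."""
    def __init__(self, num_blocks=16):
        super(BigNet, self).__init__()
        layers = []
        for _ in range(num_blocks):
            layers += [
                # padding keeps the sequence length constant between blocks
                nn.Conv1d(CHANNEL_NUM, CHANNEL_NUM, KERNEL_SIZE,
                          padding=KERNEL_SIZE // 2, groups=CHANNEL_NUM),
                nn.Conv1d(CHANNEL_NUM, CHANNEL_NUM, 1),
                nn.ReLU(),
            ]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)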