`parallel_apply` is slower than serial application

I am building an ensemble model in PyTorch and want to evaluate its member models in parallel on different GPUs. I'm using 4 GPUs and evaluating the models with `nn.parallel.parallel_apply`; however, as the code below demonstrates, this is not really faster than evaluating the models one after another. Any ideas as to why that is?

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class parallell_ensambled_net(nn.Module):
    """Ensemble of independent ``pytorch_net`` models evaluated in parallel.

    Each member model is placed on its own CUDA device by :meth:`cuda`, and
    :meth:`forward` runs all members through ``nn.parallel.parallel_apply``.

    Args:
        num_models: Number of member networks to create.
        in_channels: Forwarded to every ``pytorch_net``.
        num_actions: Forwarded to every ``pytorch_net``.
    """

    def __init__(self, num_models: int, in_channels=4, num_actions=18):
        super(parallell_ensambled_net, self).__init__()
        # nn.ModuleList (rather than a plain list) registers the members as
        # sub-modules, so parameters(), state_dict(), train()/eval() see them.
        self.models = nn.ModuleList(
            [
                pytorch_net(in_channels=in_channels, num_actions=num_actions)
                for _ in range(num_models)
            ]
        )
        self.devices = []
        self.on_cuda = False

    def cuda(self, device_ids):
        """Move the i-th member model to ``device_ids[i]``.

        Args:
            device_ids: CUDA device indices; at least one per member model.
                Surplus ids are ignored.

        Returns:
            ``self``, matching the ``nn.Module.cuda`` convention.

        Raises:
            ValueError: If fewer device ids than member models are given
                (otherwise some models would silently stay on the CPU).
        """
        device_ids = list(device_ids)
        if len(device_ids) < len(self.models):
            raise ValueError(
                'need at least %d device ids, got %d'
                % (len(self.models), len(device_ids))
            )

        self.devices = []
        for model, device_id in zip(self.models, device_ids):
            model.cuda(device=device_id)
            first_param = next(model.parameters())
            assert first_param.is_cuda, 'must be on cuda'
            print('model on device', first_param.get_device())
            self.devices.append(first_param.get_device())
        self.on_cuda = True
        return self

    def forward(self, xes):
        """Apply member ``i`` to ``xes[i]`` and return the list of outputs.

        Args:
            xes: One input per member model; each input is expected to live
                on the same device as the model it is paired with.

        Returns:
            The list of per-model outputs from ``parallel_apply``.
        """
        return nn.parallel.parallel_apply(list(self.models), xes)


class serial_ensambled_net(nn.Module):
    """Ensemble of independent ``pytorch_net`` models evaluated one by one.

    All member models receive the same input in :meth:`forward`.

    Args:
        num_models: Number of member networks to create.
        in_channels: Forwarded to every ``pytorch_net``.
        num_actions: Forwarded to every ``pytorch_net``.
    """

    def __init__(self, num_models: int, in_channels=4, num_actions=18):
        super(serial_ensambled_net, self).__init__()
        # nn.ModuleList (rather than a plain list) registers the members as
        # sub-modules, so parameters(), state_dict(), train()/eval() see them.
        self.models = nn.ModuleList(
            [
                pytorch_net(in_channels=in_channels, num_actions=num_actions)
                for _ in range(num_models)
            ]
        )
        self.on_cuda = False

    def cuda(self, device=None):
        """Move every member model to ``device`` and mark the ensemble as on CUDA.

        Returns:
            ``self``, matching the ``nn.Module.cuda`` convention.
        """
        for model in self.models:
            model.cuda(device=device)
            assert next(model.parameters()).is_cuda, 'must be on cuda'
        self.on_cuda = True
        return self

    def forward(self, x):
        """Return a list with each member model's output for the same ``x``."""
        return [model(x) for model in self.models]




class pytorch_net(nn.Module):
    """Small convolutional network: three conv layers and two linear layers.

    ``forward`` takes channels-last image batches ``(batch, H, W, in_channels)``
    with pixel values on a 0-255 scale. ``fc4`` expects 11x11 feature maps,
    which 84x84 inputs produce.

    Args:
        in_channels: Number of input channels.
        num_actions: Size of the output layer.
    """

    def __init__(self, in_channels=4, num_actions=18):
        super(pytorch_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc4 = nn.Linear(11 * 11 * 64, 512)
        self.fc5 = nn.Linear(512, num_actions)

        # Xavier-uniform weights and zero biases for every layer.
        for layer in (self.conv1, self.conv2, self.conv3, self.fc4, self.fc5):
            torch.nn.init.xavier_uniform_(layer.weight)
            torch.nn.init.constant_(layer.bias, 0.0)

        self.on_cuda = False

    def cuda(self, device=None):
        """Move the network to CUDA, record that in ``on_cuda``, return ``self``."""
        super(pytorch_net, self).cuda(device=device)
        self.on_cuda = True
        return self

    def forward(self, x):
        """Compute the network output for a channels-last batch ``x``.

        The input tensor is left untouched.

        Returns:
            Tensor of shape ``(batch, num_actions)``.
        """
        x = x.permute(0, 3, 1, 2).float()
        # Out-of-place on purpose: for float32 input, permute() and float()
        # return views of the caller's tensor, so `x /= 255.0` would rescale
        # the caller's data on every call.
        x = x / 255.0
        x = F.relu(self.conv1(x))

        # Append one zero row at the bottom and one zero column on the right.
        # F.pad allocates on x's own device, so no per-device handling is needed.
        x = F.pad(x, (0, 1, 0, 1))

        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten() instead of view(): it also works if the conv output is not
        # stored contiguously, and yields the same (C, H, W) element order.
        x = F.relu(self.fc4(torch.flatten(x, 1)))
        return self.fc5(x)


# Number of member models created for the ensemble benchmarks.
NUM = 3
# Number of timed forward passes each benchmark averages over.
TRIES = 1000

def test_parallell():
    """Time the average forward pass of the parallel ensemble and print it.

    Builds ``NUM`` models on CUDA devices ``0 .. NUM-1``, runs 100 untimed
    warm-up passes, then times ``TRIES`` passes.
    """
    inputs = torch.randn(16, 84, 84, 4)
    model = parallell_ensambled_net(NUM)
    model.cuda(list(range(NUM)))

    # One copy of the input per model, on that model's device.
    xes = [inputs.cuda(device=device) for device in model.devices]

    def wait_for_gpus():
        # CUDA kernels are launched asynchronously; without waiting, the
        # timer would mostly measure launch overhead rather than compute.
        for device in model.devices:
            torch.cuda.synchronize(device)

    for _ in range(100):
        model(xes)
    wait_for_gpus()

    start = time.perf_counter()
    for _ in range(TRIES):
        model(xes)
    wait_for_gpus()
    diff = time.perf_counter() - start
    print('parallell took on average', diff/TRIES, 'sec for forward pass')

def test_normal():
    """Time the average forward pass of a single ``pytorch_net`` and print it.

    Runs 100 untimed warm-up passes, then times ``TRIES`` passes on the
    default CUDA device.
    """
    inputs = torch.randn(16, 84, 84, 4).cuda()
    model = pytorch_net()
    model.cuda()

    # Warm up so one-time CUDA initialisation is not part of the measurement.
    for _ in range(100):
        model(inputs)
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # timer covers the actual computation.
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(TRIES):
        model(inputs)
    torch.cuda.synchronize()
    diff = time.perf_counter() - start
    print('normal took on average', diff/TRIES, 'sec for forward pass')


def test_serial():
    """Time the average forward pass of the serial ensemble and print it.

    Builds ``NUM`` models on the default CUDA device, runs 100 untimed
    warm-up passes, then times ``TRIES`` passes.
    """
    inputs = torch.randn(16, 84, 84, 4).cuda()
    model = serial_ensambled_net(NUM)
    model.cuda()

    # Warm up so one-time CUDA initialisation is not part of the measurement.
    for _ in range(100):
        model(inputs)
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # timer covers the actual computation.
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(TRIES):
        model(inputs)
    torch.cuda.synchronize()
    diff = time.perf_counter() - start
    print('serial took on average', diff/TRIES, 'sec for forward pass')

# Run the three benchmarks in turn; each one moves its model and inputs to
# CUDA and prints its average forward-pass time.
test_parallell()
test_normal()
test_serial()

I am also working towards this goal (see this). Would you mind sharing any progress you have made?