Problem with torch.gesv on multiple GPUs

Hi, I built a model that contains not only a CNN for extracting features but also a torch.gesv call to solve Ax=B. When I train this model on a single GPU (batch size 20), torch.gesv takes about 100 ms; however, when I train it on 4 GPUs (batch size 80, with each GPU solving 20 samples), torch.gesv takes about 300 ms on each GPU. This slows my training down considerably. Why does this happen? Can you help me? Thank you!

Hi,

This can have many causes.
Could you post a minimal example that reproduces the issue, using the same inputs you use and including all of your timing code, please?

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

# pytorch 0.4.1

import pdb
import time
import os,sys

import torch
import torch.nn as nn

dtype = torch.float
device = torch.device("cuda")

class TEST(nn.Module):
  def __init__(self, b):
    super(TEST, self).__init__()
    self.b = b
  def forward(self, A):
    b_cuda = self.b.cuda(device=A.device)
    torch.cuda.synchronize()
    start = time.time()
    X, _ = torch.gesv(b_cuda, A)
    torch.cuda.synchronize()
    print(time.time()-start)
    return X


if __name__ == '__main__':


  os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2"

  A = torch.rand(30, 1024, 1024, device=device, dtype=dtype)
  b = torch.rand(10, 1024, 1)

  model = TEST(b)
  model = nn.DataParallel(model).cuda()

  for _ in range(10):
    z = model(A)


  '''
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  A = torch.rand(10, 1024, 1024, device=device, dtype=dtype)
  b = torch.rand(10, 1024, 1)

  model = TEST(b)
  model = nn.DataParallel(model).cuda()

  for _ in range(10):
    z = model(A)

  '''

Hi, here is my code. When I use multiple GPUs, the printed time is about 200 ms; however, when I use a single GPU as in the commented-out block, the printed time is about 145 ms. I do not understand why. Tested on a Titan X.

Hi,

Is that the time to run the whole script?
Because CUDA initialization, as well as random data generation, is slower when more devices are used.
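
For example, here is a minimal sketch (my own addition, not from the original exchange) showing that the very first call pays for CUDA/solver initialization, so only warmed-up iterations should be compared:

import time
import torch

# Assumed setup: the same 1024x1024 batched systems as in the script above.
A = torch.rand(10, 1024, 1024, device="cuda")
b = torch.rand(10, 1024, 1, device="cuda")

for i in range(5):
  torch.cuda.synchronize()
  start = time.time()
  X, _ = torch.gesv(b, A)  # batched solve of AX = b
  torch.cuda.synchronize()
  # iteration 0 includes context/handle creation; later iterations reflect the op itself
  print("iteration {}: {:.3f} s".format(i, time.time() - start))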

Oh, no. As in my code above, I put the timing code inside the forward of the network, so it only measures torch.gesv.

Is there any error in my code?

I slightly modified your script to print more information and cleaned up the use of b, which was oddly shared between batch elements in your version.

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

# pytorch 0.4.1

import pdb
import time
import os,sys

import torch
import torch.nn as nn

dtype = torch.float
device = torch.device("cuda")

class TEST(nn.Module):
  def __init__(self):
    super(TEST, self).__init__()
  def forward(self, A, b):
    torch.cuda.synchronize()
    start = time.time()
    X, _ = torch.gesv(b, A)
    torch.cuda.synchronize()
    print("Device: {}. Runtime: {}.".format(torch.cuda.current_device(), time.time()-start))
    return X


if __name__ == '__main__':

  device_string = "0, 1"
  batch_size = 20

  print("""Running with device_string: "{}" and batch_size: {}""".format(device_string, batch_size))

  os.environ["CUDA_VISIBLE_DEVICES"] = device_string

  A = torch.rand(batch_size, 1024, 1024, device=device, dtype=dtype)
  b = torch.rand(batch_size, 1024, 1)

  model = TEST()
  model = nn.DataParallel(model).cuda()

  for _ in range(10):
    z = model(A, b)

This is what I get for different devices:

$ python tmp.py 
Running with device_string: "0, 1" and batch_size: 20
Device: 0. Runtime: 0.983064889908.
Device: 1. Runtime: 1.21471691132.
Device: 0. Runtime: 0.122294902802.
Device: 1. Runtime: 0.263245105743.
Device: 0. Runtime: 0.113092899323.
Device: 1. Runtime: 0.26281785965.
Device: 0. Runtime: 0.11643409729.
Device: 1. Runtime: 0.266459941864.
Device: 0. Runtime: 0.116520881653.
Device: 1. Runtime: 0.266087055206.
Device: 0. Runtime: 0.117300987244.
Device: 1. Runtime: 0.269877910614.
Device: 0. Runtime: 0.116713047028.
Device: 1. Runtime: 0.26772403717.
Device: 0. Runtime: 0.115457057953.
Device: 1. Runtime: 0.262111902237.
Device: 0. Runtime: 0.118778944016.
Device: 1. Runtime: 0.26965379715.
Device: 0. Runtime: 0.116364955902.
Device: 1. Runtime: 0.268661022186.

$ python tmp.py 
Running with device_string: "0" and batch_size: 10
Device: 0. Runtime: 0.369277954102.
Device: 0. Runtime: 0.130491018295.
Device: 0. Runtime: 0.123297214508.
Device: 0. Runtime: 0.120588064194.
Device: 0. Runtime: 0.11813211441.
Device: 0. Runtime: 0.118545055389.
Device: 0. Runtime: 0.119971036911.
Device: 0. Runtime: 0.117920875549.
Device: 0. Runtime: 0.117951869965.
Device: 0. Runtime: 0.125433921814.

$ python tmp.py 
Running with device_string: "1" and batch_size: 10
Device: 0. Runtime: 0.682589054108.
Device: 0. Runtime: 0.257436037064.
Device: 0. Runtime: 0.257730007172.
Device: 0. Runtime: 0.258300065994.
Device: 0. Runtime: 0.257426977158.
Device: 0. Runtime: 0.257858991623.
Device: 0. Runtime: 0.258059978485.
Device: 0. Runtime: 0.258193016052.
Device: 0. Runtime: 0.258168935776.
Device: 0. Runtime: 0.258305072784.

This looks correct to me: each run takes the same time whether it goes through DataParallel or not.

Once the sync points inside the function are removed, the CUDA profiler shows that the ops on both devices run at the same time:
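
As a rough sketch (my addition, not the exact commands used here), a timeline like that can be captured by running the script under nvprof and emitting NVTX ranges so the ops show up by name:

# Launch with, e.g.: nvprof --profile-from-start off -o timeline.nvvp python profile_gesv.py
import torch

A = torch.rand(20, 1024, 1024, device="cuda")
b = torch.rand(20, 1024, 1, device="cuda")

torch.cuda.profiler.start()                    # open the nvprof capture window
with torch.autograd.profiler.emit_nvtx():
  X, _ = torch.gesv(b, A)                      # appears as an annotated range in nvvp
torch.cuda.profiler.stop()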

Note that the performance drop you might see when using DataParallel is because you need to copy data between GPUs (when using a single GPU, no copy is done if the input is already on the right device). These copies are the brown bars in the timeline above, and as you can see they are significant.
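
If you want to measure that copy in isolation, here is a small sketch (my own addition, assuming two visible GPUs): scatter is what DataParallel uses to split the input batch and move each chunk to its target device.

import time
import torch
from torch.nn.parallel import scatter

A = torch.rand(20, 1024, 1024, device="cuda:0")  # input already sitting on GPU 0

torch.cuda.synchronize()
start = time.time()
chunks = scatter(A, [0, 1])  # half the batch stays on GPU 0, half is copied to GPU 1
torch.cuda.synchronize()
print("scatter time: {:.4f} s".format(time.time() - start))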