Is it possible to execute two models in parallel on two CUDA streams?

case1:

# Case 1: train two independent models, one per side stream.
# `input`, `label`, the models, optimizers and losses are placeholders for
# objects defined elsewhere.
input = ...
model1 = ...
model2 = ...

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()

# Work on a side stream is not ordered after the default stream, so both side
# streams must wait for everything already queued there (e.g. whatever
# produced `input`) before reading it.
s1.wait_stream(torch.cuda.current_stream())
s2.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(s1):
    output1 = model1(input)
    optimizer1.zero_grad()
    loss1(output1, label).backward()
    optimizer1.step()

with torch.cuda.stream(s2):
    output2 = model2(input)
    # Use model2's own optimizer and output. Reusing optimizer1/output1 here
    # would leave model2 without gradients. It would also backpropagate a
    # second time through output1's graph, which the first backward() freed.
    optimizer2.zero_grad()
    loss2(output2, label).backward()
    optimizer2.step()

# Make the default stream wait for both side streams before anything queued
# on it uses their results (outputs, updated parameters).
torch.cuda.current_stream().wait_stream(s1)
torch.cuda.current_stream().wait_stream(s2)

case2:

# Case 2 (baseline): train both models one after another on a single side
# stream. `input`, `label`, the models, optimizers and losses are placeholders
# for objects defined elsewhere.
input = ...
model1 = ...
model2 = ...

s1 = torch.cuda.Stream()

# The side stream is not ordered after the default stream; wait for work
# already queued there (e.g. whatever produced `input`) first.
s1.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(s1):
    output1 = model1(input)
    optimizer1.zero_grad()
    loss1(output1, label).backward()
    optimizer1.step()

    output2 = model2(input)
    # model2 needs its own optimizer and its own output (see case 1).
    optimizer2.zero_grad()
    loss2(output2, label).backward()
    optimizer2.step()

# Make the default stream wait for s1 before its results are used there.
torch.cuda.current_stream().wait_stream(s1)
    

I expected the first case to take about half as long as the second, but I found that both cases took about the same amount of time.

I don’t know why.

How do I implement parallel execution of multiple models on multiple CUDA streams?

Depending on the model and thus the workload, the CPU might not be able to run ahead and schedule the kernel launches fast enough.
You could profile it using e.g. Nsight Systems and check whether the kernels are overlapping, or whether they are so short that they are effectively executed “sequentially” on these two streams.
As a quick check, you could replace the models with huge matrix multiplications and profile those.

Hi @ptrblck and @ronda,

I have been trying to do something similar with my model. I have created three different copies of the same model and I would like to run them concurrently. Right now, I am running these models in a Jupyter notebook. The structure of the code looks like this:

# Three independent copies of the same GRU model; each one is trained on its
# own CUDA stream (model0 -> s1, model1 -> s2, model2 -> s3).
model0 = GRUCell(output_dim, hidden_dim, batch_size-1, output_dim, num_layers).double().cuda()
model1 = GRUCell(output_dim, hidden_dim, batch_size-1, output_dim, num_layers).double().cuda()
model2 = GRUCell(output_dim, hidden_dim, batch_size-1, output_dim, num_layers).double().cuda()

h0 = model0.init_hidden()
h1 = model1.init_hidden()
h2 = model2.init_hidden()

optimizer0 = torch.optim.Adam(model0.parameters(), lr=learning_rate, weight_decay=0.00000)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=learning_rate, weight_decay=0.00000)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=learning_rate, weight_decay=0.00000)

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
s3 = torch.cuda.Stream()

loss_fn0 = torch.nn.MSELoss()
loss_fn1 = torch.nn.MSELoss()
loss_fn2 = torch.nn.MSELoss()

# Intermediate code not relevant to the thread hence skipped

for epoch in epochs:
    for k in range(sequence_len):
        # Create three shifted copies of the same data. The index list is the
        # same for all three calls, so build it once.
        batch_idxs = (np.array(idxs[0]) - k).tolist()
        x_batch_train0, y_batch_train0, l2, l1 = batch_creator_4(batch_idxs, total_len, sequence_len, predict_len, batch_size-1, shift=2)
        x_batch_train1, y_batch_train1, _, _ = batch_creator_4(batch_idxs, total_len, sequence_len, predict_len, batch_size-1, shift=0)
        x_batch_train2, y_batch_train2, _, _ = batch_creator_4(batch_idxs, total_len, sequence_len, predict_len, batch_size-1, shift=1)

        # Side streams are not ordered after the default stream. Make each of
        # them wait for everything already queued there before reading those
        # tensors. That includes model/hidden-state creation, the skipped setup
        # code, and possibly the batches themselves.
        default_stream = torch.cuda.current_stream()
        s1.wait_stream(default_stream)
        s2.wait_stream(default_stream)
        s3.wait_stream(default_stream)

        tic = time.time()

        # Interleave the launches: each time step (and each optimizer update)
        # is enqueued for all three models before moving on. That way, kernels
        # from different streams are queued together and can overlap on the
        # GPU. The previous version enqueued one model's entire workload before
        # giving the next stream any work. That leaves almost nothing to
        # overlap whenever the CPU cannot run far ahead of the GPU.
        # NOTE(review): `.to(..., non_blocking=True)` is only truly
        # asynchronous when the source tensor is in pinned memory. Copies from
        # pageable memory may block the CPU until the stream catches up.
        # Consider pinning the batches returned by batch_creator_4, or moving
        # each whole batch to the GPU once per `k` — confirm what it returns.
        for i in range(l2):
            for t in range(sequence_len):
                with torch.cuda.stream(s1):
                    x_batch_train_sub0 = torch.reshape(x_batch_train0[i, :, t, :], (batch_size-1, 3)).to(device, non_blocking=True)
                    output0, h0, _ = model0((x_batch_train_sub0, h0))
                with torch.cuda.stream(s2):
                    x_batch_train_sub1 = torch.reshape(x_batch_train1[i, :, t, :], (batch_size-1, 3)).to(device, non_blocking=True)
                    output1, h1, _ = model1((x_batch_train_sub1, h1))
                with torch.cuda.stream(s3):
                    x_batch_train_sub2 = torch.reshape(x_batch_train2[i, :, t, :], (batch_size-1, 3)).to(device, non_blocking=True)
                    output2, h2, _ = model2((x_batch_train_sub2, h2))

            # Each backward op runs on the stream of its forward op. Loss,
            # zero_grad and step are issued on the model's own stream as well.
            with torch.cuda.stream(s1):
                h0 = h0.detach()
                y_batch_train_sub0 = torch.reshape(y_batch_train0[i, :, 0], (batch_size-1, 1)).to(device, non_blocking=True)
                loss_tot0 = loss_fn0(output0, y_batch_train_sub0)
                optimizer0.zero_grad()
                loss_tot0.backward()
                optimizer0.step()
            with torch.cuda.stream(s2):
                h1 = h1.detach()
                y_batch_train_sub1 = torch.reshape(y_batch_train1[i, :, 0], (batch_size-1, 1)).to(device, non_blocking=True)
                loss_tot1 = loss_fn1(output1, y_batch_train_sub1)
                optimizer1.zero_grad()
                loss_tot1.backward()
                optimizer1.step()
            with torch.cuda.stream(s3):
                h2 = h2.detach()
                y_batch_train_sub2 = torch.reshape(y_batch_train2[i, :, 0], (batch_size-1, 1)).to(device, non_blocking=True)
                loss_tot2 = loss_fn2(output2, y_batch_train_sub2)
                optimizer2.zero_grad()
                loss_tot2.backward()
                optimizer2.step()

        # Wait for all streams, so toc - tic measures the GPU work and not just
        # the kernel launches.
        torch.cuda.synchronize()
        toc = time.time()
        print(toc-tic)

For

  • hidden_dim (feature size) = 1000: time taken by one iteration of the `for k` loop (toc - tic) = 0.76 s
  • hidden_dim (feature size) = 5000: time taken by one iteration of the `for k` loop (toc - tic) ≈ 9 s
  • hidden_dim (feature size) = 6000: time taken by one iteration of the `for k` loop (toc - tic) = 13 s

Doing the same calculation with serial code as shown below:

# Same model, optimizer and loss creation code as above

# Intermediate code irrelevant to this thread, hence removed.
# Serial baseline: all three models run on the default stream.
for epoch in epochs:
    for k in range(sequence_len):
        # The index list is the same for all three shifted copies; build it once.
        batch_idxs = (np.array(idxs[0]) - k).tolist()
        x_batch_train0, y_batch_train0, l2, l1 = batch_creator_4(batch_idxs, total_len, sequence_len, predict_len, batch_size-1, shift=2)
        x_batch_train1, y_batch_train1, _, _ = batch_creator_4(batch_idxs, total_len, sequence_len, predict_len, batch_size-1, shift=0)
        x_batch_train2, y_batch_train2, _, _ = batch_creator_4(batch_idxs, total_len, sequence_len, predict_len, batch_size-1, shift=1)
        tic = time.time()
        for i in range(l2):
            for t in range(sequence_len):
                x_batch_train_sub0 = torch.reshape(x_batch_train0[i, :, t, :], (batch_size-1, 3)).to(device, non_blocking=True)
                x_batch_train_sub1 = torch.reshape(x_batch_train1[i, :, t, :], (batch_size-1, 3)).to(device, non_blocking=True)
                x_batch_train_sub2 = torch.reshape(x_batch_train2[i, :, t, :], (batch_size-1, 3)).to(device, non_blocking=True)

                output0, h0, _ = model0((x_batch_train_sub0, h0))
                output1, h1, _ = model1((x_batch_train_sub1, h1))
                output2, h2, _ = model2((x_batch_train_sub2, h2))

            h0 = h0.detach()
            h1 = h1.detach()
            h2 = h2.detach()
            y_batch_train_sub0 = torch.reshape(y_batch_train0[i, :, 0], (batch_size-1, 1)).to(device, non_blocking=True)
            y_batch_train_sub1 = torch.reshape(y_batch_train1[i, :, 0], (batch_size-1, 1)).to(device, non_blocking=True)
            y_batch_train_sub2 = torch.reshape(y_batch_train2[i, :, 0], (batch_size-1, 1)).to(device, non_blocking=True)

            # Use the per-model loss functions created by the setup code
            # (a bare `loss_fn` is never defined there).
            loss_tot0 = loss_fn0(output0, y_batch_train_sub0)
            loss_tot1 = loss_fn1(output1, y_batch_train_sub1)
            loss_tot2 = loss_fn2(output2, y_batch_train_sub2)

            optimizer0.zero_grad()
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            loss_tot0.backward()
            loss_tot1.backward()
            loss_tot2.backward()
            optimizer0.step()
            optimizer1.step()
            optimizer2.step()

        # CUDA kernels run asynchronously. Without this call the timer would
        # stop when the last kernel was *launched*. The stream version
        # synchronizes before `toc`, so this keeps the two timings comparable.
        torch.cuda.synchronize()
        toc = time.time()
        print(toc-tic)
  • For hidden_dim = 1000, I get an average execution time of toc-tic = 0.76s
  • For hidden_dim = 5000, I get an average execution time of toc-tic = 13s
  • For hidden_dim = 6000, I get an average execution time of toc-tic = 19s

This means that the execution is not happening fully in parallel, and that the benefits (if any) become visible only for a large network. I have also looked at several threads (1, 2, 3) on this forum and these issues (a, b) on GitHub. I understand there may be plenty of redundant commands in my code, but I wanted to create an example that is simple to explain. I don’t have any experience with PyTorch’s CUDA interface. Any suggestions are welcome.