Parallel execution of modules in nn.ModuleList

I tried the approach from the Stack Overflow answer, which fuses all heads into grouped `Conv1d` layers. It is faster than the naive loop over heads on the GPU, but slower on the CPU:

import torch
from torch import nn
import numpy as n

class MultiHeadParallel(nn.Module):
    """Evaluate ``nb_heads`` independent two-layer MLPs in a single pass.

    Each head computes ``Linear(input_dim, hidden_size) -> Tanh ->
    Linear(hidden_size, output_dim)``. Instead of looping over heads, the heads
    are fused into two grouped 1x1 ``Conv1d`` layers. With ``groups=nb_heads``,
    group ``h`` only sees its own slice of channels, so it acts as head ``h``'s
    linear layer.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output features produced by each head.
        hidden_size: Width of each head's hidden layer.
        nb_heads: Number of independent heads.
    """

    def __init__(self, input_dim, output_dim, hidden_size=32, nb_heads=1):
        super().__init__()

        self.network = nn.Sequential(
            nn.Conv1d(
                in_channels=input_dim * nb_heads,
                out_channels=hidden_size * nb_heads,
                kernel_size=1,
                groups=nb_heads,
            ),
            nn.Tanh(),
            nn.Conv1d(
                in_channels=hidden_size * nb_heads,
                out_channels=output_dim * nb_heads,
                kernel_size=1,
                groups=nb_heads,
            ),
        )
        self.nb_heads = nb_heads

    def forward(self, x):
        """Run every head on ``x``.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim, nb_heads)``, where
            ``out[..., h]`` is head ``h``'s output. This is the same layout as
            stacking per-head outputs with ``torch.stack(..., dim=-1)``. The
            result is a transposed view; call ``.contiguous()`` if a contiguous
            tensor is required.
        """
        # Give every group (head) its own copy of the input. The channels end up
        # head-major: (batch, nb_heads * input_dim, 1).
        x = x.repeat(1, self.nb_heads).unsqueeze(-1)
        flat = self.network(x)  # (batch, nb_heads * output_dim, 1), head-major
        batch_size = flat.shape[0]
        # The grouped conv places head h's outputs in channels
        # [h * output_dim, (h + 1) * output_dim). Split on heads first, then move
        # the head axis last. A direct view(batch, -1, nb_heads) would interleave
        # features coming from different heads.
        return flat.view(batch_size, self.nb_heads, -1).transpose(1, 2)

class MultiHeadNaive(nn.Module):
    """Loop-based baseline: ``nb_heads`` separate MLPs evaluated one after another.

    Every head is ``Linear(input_dim, hidden_size) -> Tanh ->
    Linear(hidden_size, output_dim)``. The heads are kept in ``self.networks``.
    """

    def __init__(self, input_dim, output_dim, hidden_size=32, nb_heads=1):
        super().__init__()

        def build_head():
            # One independent MLP; parameters are created in the same order as
            # the layers are listed.
            return nn.Sequential(
                nn.Linear(input_dim, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, output_dim),
            )

        self.networks = nn.ModuleList([build_head() for _ in range(nb_heads)])

    def forward(self, x):
        """Apply each head to ``x`` and stack the results along a new last axis.

        The result has shape ``(*x.shape[:-1], output_dim, nb_heads)``.
        """
        per_head = tuple(head(x) for head in self.networks)
        return torch.stack(per_head, dim=-1)

BATCH_SIZE = 128
IN_DIM = 256
HIDDEN_SIZE = 256
OUT_DIM = 256
NB_HEADS = 1000

net_parallel = MultiHeadParallel(
    input_dim=IN_DIM,
    output_dim=OUT_DIM,
    hidden_size=HIDDEN_SIZE,
    nb_heads=NB_HEADS,
)

net_naive = MultiHeadNaive(
    input_dim=IN_DIM,
    output_dim=OUT_DIM,
    hidden_size=HIDDEN_SIZE,
    nb_heads=NB_HEADS,
)
x = torch.randn(BATCH_SIZE, IN_DIM)

print("***Without GPU***")
print("Naive")
%time y_naive = net_naive(x)
print("\nWith Conv1D")
%time y_parallel = net_parallel(x)

print("\n***With GPU***")
net_parallel.cuda()
net_naive.cuda()
x = x.cuda()
print("Naive")
%time y_naive = net_naive(x); torch.cuda.synchronize()
print("\nWith Conv1D")
%time y_parallel = net_parallel(x); torch.cuda.synchronize()

Outputs:

***Without GPU***
Naive
CPU times: user 176 ms, sys: 0 ns, total: 176 ms
Wall time: 176 ms

With Conv1D
CPU times: user 629 ms, sys: 17 µs, total: 629 ms
Wall time: 633 ms

***With GPU***
Naive
CPU times: user 11.2 ms, sys: 3.05 ms, total: 14.2 ms
Wall time: 14.1 ms

With Conv1D
CPU times: user 4.07 ms, sys: 971 µs, total: 5.04 ms
Wall time: 4.86 ms

Edit: added the `torch.cuda.synchronize()` calls and updated the timings.

2 Likes