I tried the approach from the Stack Overflow answer. Compared with the naive per-head loop, it is faster on a GPU but slower on the CPU:
import torch
from torch import nn
import numpy as n
class MultiHeadParallel(nn.Module):
    """Evaluate ``nb_heads`` independent two-layer MLP heads in a single pass.

    Each head is ``Linear -> Tanh -> Linear``, expressed as grouped 1x1
    convolutions (``groups=nb_heads``) so that every head has its own
    weights but all heads run in one convolution call.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output features produced by each head.
        hidden_size: Width of each head's hidden layer.
        nb_heads: Number of independent heads.
    """

    def __init__(self, input_dim, output_dim, hidden_size=32, nb_heads=1):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv1d(
                in_channels=input_dim * nb_heads,
                out_channels=hidden_size * nb_heads,
                kernel_size=1,
                groups=nb_heads,
            ),
            nn.Tanh(),
            nn.Conv1d(
                in_channels=hidden_size * nb_heads,
                out_channels=output_dim * nb_heads,
                kernel_size=1,
                groups=nb_heads,
            ),
        )
        self.nb_heads = nb_heads

    def forward(self, x):
        """Apply every head to ``x``.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim, nb_heads)`` where
            ``[..., h]`` is the output of head ``h``. The result is a
            transposed (non-contiguous) view; call ``.contiguous()`` if a
            dense layout is required.
        """
        batch_size = x.shape[0]
        # Tile the features once per head and add a length-1 "sequence" axis
        # so each convolution group receives its own copy of the input.
        tiled = x.repeat(1, self.nb_heads).unsqueeze(-1)
        flat = self.network(tiled)
        # Grouped convolutions lay out channels head-major
        # (channel = head * output_dim + feature), so split into
        # (batch, head, feature) first and only then move heads last.
        # Viewing directly as (batch, -1, nb_heads) would interleave heads.
        per_head = flat.view(batch_size, self.nb_heads, -1)
        return per_head.transpose(1, 2)
class MultiHeadNaive(nn.Module):
    """Reference multi-head MLP that evaluates each head one after another.

    Every head is an independent ``Linear -> Tanh -> Linear`` network held
    in ``self.networks``.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output features produced by each head.
        hidden_size: Width of each head's hidden layer.
        nb_heads: Number of independent heads.
    """

    def __init__(self, input_dim, output_dim, hidden_size=32, nb_heads=1):
        super().__init__()
        self.networks = nn.ModuleList()
        for _ in range(nb_heads):
            network = nn.Sequential(
                nn.Linear(input_dim, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, output_dim),
            )
            self.networks.append(network)

    def forward(self, x):
        """Apply every head to ``x`` and stack the results on a new last axis.

        Args:
            x: Tensor whose last dimension is ``input_dim``, e.g.
                ``(batch, input_dim)``.

        Returns:
            Tensor with the per-head outputs stacked along the final
            dimension, e.g. ``(batch, output_dim, nb_heads)``.
        """
        outputs = [net(x) for net in self.networks]
        return torch.stack(outputs, dim=-1)
# Benchmark configuration: one random batch is pushed through both
# multi-head implementations, built with identical hyper-parameters.
BATCH_SIZE = 128
IN_DIM = 256
HIDDEN_SIZE = 256
OUT_DIM = 256
NB_HEADS = 1000
# Grouped-convolution implementation (all heads in one pass).
net_parallel = MultiHeadParallel(
input_dim=IN_DIM,
output_dim=OUT_DIM,
hidden_size=HIDDEN_SIZE,
nb_heads=NB_HEADS,
)
# Reference implementation (one small network per head, run in a loop).
net_naive = MultiHeadNaive(
input_dim=IN_DIM,
output_dim=OUT_DIM,
hidden_size=HIDDEN_SIZE,
nb_heads=NB_HEADS,
)
# Input batch of shape (BATCH_SIZE, IN_DIM), created on the CPU.
x = torch.randn(BATCH_SIZE, IN_DIM)
# --- CPU timings -----------------------------------------------------------
# `%time` is an IPython/Jupyter line magic, so this cell is not plain Python.
# Each statement is timed once.
# NOTE(review): the forward passes run with autograd enabled and only a
# single time; wrapping them in torch.no_grad() and averaging several runs
# would probably give steadier numbers -- confirm before comparing.
print("***Without GPU***")
print("Naive")
%time y_naive = net_naive(x)
print("\nWith Conv1D")
%time y_parallel = net_parallel(x)
# --- GPU timings -----------------------------------------------------------
# Move both models and the input to the GPU.
# NOTE(review): .cuda() is called unconditionally, so this part presumably
# fails on a machine without CUDA -- guard with torch.cuda.is_available()
# if that matters.
print("\n***With GPU***")
net_parallel.cuda()
net_naive.cuda()
x = x.cuda()
# torch.cuda.synchronize() is part of the timed statement so the measurement
# waits for the GPU work to finish rather than only timing the kernel launch.
# NOTE(review): there is no warm-up call, so the first GPU measurement may
# include one-time CUDA start-up cost -- TODO confirm.
print("Naive")
%time y_naive = net_naive(x); torch.cuda.synchronize()
print("\nWith Conv1D")
%time y_parallel = net_parallel(x); torch.cuda.synchronize()
Outputs:
***Without GPU***
Naive
CPU times: user 176 ms, sys: 0 ns, total: 176 ms
Wall time: 176 ms
With Conv1D
CPU times: user 629 ms, sys: 17 µs, total: 629 ms
Wall time: 633 ms
***With GPU***
Naive
CPU times: user 11.2 ms, sys: 3.05 ms, total: 14.2 ms
Wall time: 14.1 ms
With Conv1D
CPU times: user 4.07 ms, sys: 971 µs, total: 5.04 ms
Wall time: 4.86 ms
Edit: added the `torch.cuda.synchronize()` call to the GPU measurements and updated the timings accordingly.