I am building an ensemble model in PyTorch and want to evaluate its member models in parallel on different GPUs. I'm using 4 GPUs and evaluating the models with nn.parallel.parallel_apply, but, as the code below demonstrates, this is not really faster than evaluating the models in series. Any ideas as to why that is?
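For reference, my understanding of parallel_apply is that it takes a list of modules and an equally long list of inputs and runs each (module, input) pair in its own thread, so the forward passes on different GPUs should overlap. A minimal sketch of that call pattern (m0, m1, x0, x1 are just placeholders, and it assumes at least two visible GPUs; this is not my actual ensemble):

import torch
import torch.nn as nn

# two toy modules, one per GPU
m0 = nn.Linear(8, 8).cuda(0)
m1 = nn.Linear(8, 8).cuda(1)

# one input per module, already on the matching device
x0 = torch.randn(4, 8, device='cuda:0')
x1 = torch.randn(4, 8, device='cuda:1')

# each pair runs in its own thread; outputs stay on their devices
outs = nn.parallel.parallel_apply([m0, m1], [x0, x1])

And here is the full benchmark I am running: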
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class parallell_ensambled_net(nn.Module):
    def __init__(self, num_models: int, in_channels=4, num_actions=18):
        super(parallell_ensambled_net, self).__init__()
        self.models = []
        for _ in range(num_models):
            self.models.append(pytorch_net(in_channels=in_channels, num_actions=num_actions))

    def cuda(self, device_ids):
        """ move each model to its own GPU and remember the device ids """
        self.devices = []
        for model, device_id in zip(self.models, device_ids):
            model.cuda(device=device_id)
            assert next(model.parameters()).is_cuda, 'must be on cuda'
            print('model on device', next(model.parameters()).get_device())
            self.devices.append(next(model.parameters()).get_device())
        self.on_cuda = True

    def forward(self, xes):
        # one input per model, each already on that model's GPU
        outputs = nn.parallel.parallel_apply(self.models, xes)
        return outputs
class serial_ensambled_net(nn.Module):
    def __init__(self, num_models: int, in_channels=4, num_actions=18):
        super(serial_ensambled_net, self).__init__()
        self.models = []
        for _ in range(num_models):
            self.models.append(pytorch_net(in_channels=in_channels, num_actions=num_actions))

    def cuda(self, device=None):
        """ move every model to the same GPU """
        for model in self.models:
            model.cuda(device=device)
            assert next(model.parameters()).is_cuda, 'must be on cuda'
        self.on_cuda = True

    def forward(self, x):
        # evaluate the models one after the other on the same input
        outputs = [mod(x) for mod in self.models]
        return outputs
class pytorch_net(nn.Module):
    def __init__(self, in_channels=4, num_actions=18):
        super(pytorch_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc4 = nn.Linear(11 * 11 * 64, 512)
        self.fc5 = nn.Linear(512, num_actions)
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        torch.nn.init.xavier_uniform_(self.conv2.weight)
        torch.nn.init.xavier_uniform_(self.conv3.weight)
        torch.nn.init.xavier_uniform_(self.fc4.weight)
        torch.nn.init.xavier_uniform_(self.fc5.weight)
        torch.nn.init.constant_(self.conv1.bias, 0.0)
        torch.nn.init.constant_(self.conv2.bias, 0.0)
        torch.nn.init.constant_(self.conv3.bias, 0.0)
        torch.nn.init.constant_(self.fc4.bias, 0.0)
        torch.nn.init.constant_(self.fc5.bias, 0.0)
        self.on_cuda = False

    def cuda(self, device=None):
        """ move the model to the GPU and remember that it lives there """
        super(pytorch_net, self).cuda(device=device)
        self.on_cuda = True
    def forward(self, x):
        # NHWC input -> NCHW float, scaled by 1/255
        x = x.permute(0, 3, 1, 2).float()
        x /= 255.0
        x = F.relu(self.conv1(x))
        # manual zero padding of the 21x21 conv1 output to 22x22,
        # creating the pad on the GPU when the model is on CUDA
        first_dim = x.shape[:2] + (1, 21)
        zero_pad = torch.zeros(*first_dim)
        if self.on_cuda:
            zero_pad = zero_pad.cuda()
        x = torch.cat((x, zero_pad), 2)
        second_dim = x.shape[:2] + (22, 1)
        zero_pad = torch.zeros(*second_dim)
        if self.on_cuda:
            zero_pad = zero_pad.cuda()
        x = torch.cat((x, zero_pad), 3)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc4(x.view(x.size(0), -1)))
        return self.fc5(x)
NUM = 3
TRIES = 1000
def test_parallell():
    inputs = torch.randn(16, 84, 84, 4)
    model = parallell_ensambled_net(NUM)
    model.cuda([0, 1, 2, 3])
    # one copy of the input per model, each on that model's GPU
    xes = [inputs.cuda(device=idx) for idx in range(len(model.models))]
    # warm-up passes before timing
    for _ in range(100):
        output = model(xes)
    start = time.time()
    for _ in range(TRIES):
        output = model(xes)
    diff = time.time() - start
    print('parallell took on average', diff / TRIES, 'sec for forward pass')
def test_normal():
    inputs = torch.randn(16, 84, 84, 4).cuda()
    model = pytorch_net()
    model.cuda()
    start = time.time()
    for _ in range(TRIES):
        output = model(inputs)
    diff = time.time() - start
    print('normal took on average', diff / TRIES, 'sec for forward pass')
def test_serial():
    inputs = torch.randn(16, 84, 84, 4).cuda()
    model = serial_ensambled_net(NUM)
    model.cuda()
    start = time.time()
    for _ in range(TRIES):
        output = model(inputs)
    diff = time.time() - start
    print('serial took on average', diff / TRIES, 'sec for forward pass')
test_parallell()
test_normal()
test_serial()
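In case the timing methodology matters: since CUDA kernels are launched asynchronously, a timing loop could also synchronize the GPUs before reading the clock. Below is only a sketch of such a variant (the timed_forward helper and its device_ids argument are names I made up here; this is not what produced the observations above):

def timed_forward(model, inputs, device_ids, tries=TRIES):
    """ sketch: time forward passes with explicit GPU synchronization """
    torch.cuda.synchronize()               # finish pending work on the current device first
    start = time.time()
    for _ in range(tries):
        output = model(inputs)
    for device_id in device_ids:
        torch.cuda.synchronize(device_id)  # wait until every GPU has finished its kernels
    return (time.time() - start) / tries

Usage would be something like timed_forward(model, xes, [0, 1, 2]) for the parallel ensemble and timed_forward(model, inputs, [0]) for the single or serial case.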