I want to train an ensemble of NNs on a single GPU in parallel.
Currently I’m doing this:
for model in models:
    model.to('cuda')
    train_model(model, ...)
Each model is quite small, but GPU utilisation is tiny (~3%), which makes me think the training is happening serially. Indeed, if I increase the number of ensemble members, the training time grows roughly proportionally. This confuses me, since CUDA kernel launches are asynchronous, so I would have expected at least some overlap. Is there any reason why this shouldn't run in parallel?
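For reference, my (possibly wrong) understanding is that getting real overlap on a single GPU requires separate CUDA streams, roughly like the sketch below; models, num_steps and train_step (one forward/backward/optimizer step on a batch that is already on the GPU) are stand-ins here, not my actual code:

import torch

streams = [torch.cuda.Stream() for _ in models]
for model in models:
    model.to('cuda')

for step in range(num_steps):
    for model, stream in zip(models, streams):
        # Kernels queued on different streams are allowed to overlap on the GPU
        with torch.cuda.stream(stream):
            train_step(model)  # hypothetical helper: forward/backward/optimizer.step on one batch
torch.cuda.synchronize()  # wait for all streams before timing or evaluating

I'm not sure whether this is even the right mental model, which is part of my question.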
I've tried to go a step further, taking inspiration from torchensemble and using joblib (see the code below), but I'm not seeing any speedup whatsoever.
import torch
from torch import nn
from torch.utils.data import DataLoader
from joblib import Parallel, delayed


def _parallel_fit_per_epoch(
    train_loader,
    estimator,
    optimizer,
    criterion,
    device,
):
    # Train a single estimator for one epoch on its own loader.
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = estimator(data)
        loss = criterion(output, target.unsqueeze(-1).float())
        loss.backward()
        optimizer.step()
    return estimator, optimizer, loss
class Regressor(nn.Module):
    # etc.


class Ensemble(nn.Module):
    def __init__(self, num_models, ...):
        super().__init__()
        self.num_models = num_models
        self.models = [Regressor(...) for _ in range(self.num_models)]
        # etc.
    def fit(self, train_dataset, valid_dataset, epochs, batch_size, learning_rate,
            weight_decay, patience, save_model=False, save_dir=None):
        optimizers = []
        train_loaders = []
        for model in self.models:
            opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
            optimizers.append(opt)
            # One independent DataLoader per ensemble member
            train_loaders.append(DataLoader(train_dataset, batch_size=batch_size, shuffle=True))
        with Parallel(n_jobs=self.num_models) as parallel:
            # Training loop
            for epoch in range(epochs):
                rets = parallel(
                    delayed(_parallel_fit_per_epoch)(
                        dataloader,
                        estimator,
                        optimizer,
                        self.loss,
                        self.device,
                    )
                    for estimator, optimizer, dataloader in zip(
                        self.models, optimizers, train_loaders
                    )
                )
                estimators, optimizers, losses = [], [], []
                for estimator, optimizer, loss in rets:
                    estimators.append(estimator)
                    optimizers.append(optimizer)
                    losses.append(loss)
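For completeness, the call site looks roughly like this (the argument values are illustrative, not my real configuration):

ensemble = Ensemble(num_models=8, ...)  # constructor args elided as above; sizes illustrative
ensemble.fit(
    train_dataset,
    valid_dataset,
    epochs=100,
    batch_size=256,
    learning_rate=1e-3,
    weight_decay=0.0,
    patience=10,
)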
I'm running all of this on a Slurm cluster, using:
#SBATCH --nodes=1
#SBATCH --mem=120G
#SBATCH --gres=gpu:1
Are there any other directives I should be using?