I am new to Optuna and am trying a simple DDP example with PyTorch, where I want to use DDP for data parallelism across 2 GPUs. I am using mp.spawn for data parallelism within each trial; I am not trying to parallelize multiple trials. My code is attached below. The problem I am facing is that when I pass the "trial" object into the spawned worker function through the "callback", the trial is not reported properly and is therefore never pruned. Can you please suggest how I can resolve this?
For reference, I have also included a second script without DDP, where pruning works as expected; I am trying to replicate the same result (the same trials being pruned) with DDP.
ddp_script.py
import os
import logging
import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner
from optuna.integration import PyTorchLightningPruningCallback
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler
optuna.logging.set_verbosity(optuna.logging.DEBUG)
class MLP(torch.nn.Module):
    def __init__(self, n_layers, hidden_dim, in_dim=10, out_dim=3):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
class PruningCallback:
    def __init__(self, trial, monitor="accuracy"):
        self.trial = trial
        self.monitor = monitor

    def on_epoch_end(self, epoch, metrics):
        value = metrics.get(self.monitor)
        if value is None:
            return
        self.trial.report(value, step=epoch)
        if self.trial.should_prune():
            raise optuna.TrialPruned()
def objective(rank, world_size, params, callback, return_dict):
    setup(rank, world_size)
    torch.manual_seed(42)
    device = torch.device(f"cuda:{rank}")
    in_dim = 10
    out_dim = 3
    num_train_samples = 500
    num_val_samples = 100
    num_epochs = 10
    batch_size = 64
    train_data = torch.rand(num_train_samples, in_dim).to(device)
    val_data = torch.rand(num_val_samples, in_dim).to(device)
    train_targets = torch.randint(0, out_dim, (num_train_samples,)).to(device)
    val_targets = torch.randint(0, out_dim, (num_val_samples,)).to(device)
    # train_dataset = TensorDataset(train_data, train_targets)
    # val_dataset = TensorDataset(val_data, val_targets)
    # # print(len(train_dataset), len(val_dataset))
    # train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    # # print(len(train_sampler), len(val_sampler))
    # train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
    # val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=batch_size)
    # # print(len(train_loader), len(val_loader))
    model = MLP(params["n_layers"], params["hidden_dim"], in_dim, out_dim).to(device)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_function = torch.nn.CrossEntropyLoss()
    out_dir = "./multirun/mlp-optuna-test"
    os.makedirs(out_dir, exist_ok=True)
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        train_outputs = model(train_data)
        loss = loss_function(train_outputs, train_targets)
        loss.backward()
        optimizer.step()
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_data)
            val_predictions = torch.argmax(val_outputs, dim=1)
            val_correct = (val_predictions == val_targets)
            acc = int(val_correct.sum()) / len(val_targets)
        # model.train()
        # # train_sampler.set_epoch(epoch)
        # for inputs, targets in train_loader:
        #     inputs, targets = inputs.to(device), targets.to(device)
        #     optimizer.zero_grad()
        #     outputs = model(inputs)
        #     loss = loss_function(outputs, targets)
        #     loss.backward()
        #     optimizer.step()
        # model.eval()
        # correct = 0
        # total = 0
        # with torch.no_grad():
        #     for inputs, targets in val_loader:
        #         inputs, targets = inputs.to(device), targets.to(device)
        #         outputs = model(inputs)
        #         preds = torch.argmax(outputs, dim=1)
        #         correct += (preds == targets).sum().item()
        #         total += targets.size(0)
        # acc = correct / total
        # average the validation accuracy across ranks (all_reduce sums, then divide by world_size)
        acc_tensor = torch.tensor(acc, device=device)
        dist.all_reduce(acc_tensor)
        acc_avg = acc_tensor.item() / world_size
        # print(rank, acc, acc_tensor, acc_avg)
        if rank == 0:
            callback.on_epoch_end(epoch, {"accuracy": acc_avg})
        # if rank == 0:
        #     trial.report(acc_avg, epoch)
        #     if trial.should_prune():
        #         cleanup()
        #         raise optuna.exceptions.TrialPruned()
    if rank == 0:
        return_dict["result"] = acc_avg
    cleanup()
def ddp_objective(trial):
    params = {
        "n_layers": trial.suggest_int("n_layers", 1, 5),
        "hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
    }
    world_size = 2  # number of GPUs
    manager = mp.Manager()
    return_dict = manager.dict()
    callback = PruningCallback(trial, monitor="accuracy")
    mp.spawn(
        objective,
        args=(world_size, params, callback, return_dict),
        nprocs=world_size,
        join=True,
    )
    return return_dict["result"]
if __name__ == "__main__":
    sampler = TPESampler(seed=42)
    pruner = MedianPruner(n_startup_trials=3, n_warmup_steps=1)
    # study = optuna.create_study(direction="minimize", sampler=sampler, pruner=NopPruner())
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    print(f"Pruner: {study.pruner}")
    print(f"Sampler: {study.sampler}")
    study.optimize(ddp_objective, n_trials=20)
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    print("Study statistics:")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
    print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")
    print("  Number of pruned trials ---: ", len(pruned_trials))
    print("  Number of complete trials ---: ", len(complete_trials))
    print("Best trial:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print("  Params:")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
whereas script.py is the version without DDP:
import hydra
from omegaconf import DictConfig
import os
import logging
import torch
import torch.nn as nn
import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner
class MLP(torch.nn.Module):
    def __init__(self, in_dim, out_dim, n_layers, hidden_dim):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
        print(f'n_layers:{n_layers}')
        for _ in range(n_layers):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def objective(trial):
    n_layers = trial.suggest_int("n_layers", 1, 5)
    hidden_dim = trial.suggest_int("hidden_dim", 32, 64)
    in_dim = 10
    out_dim = 3
    num_train_samples = 500
    num_val_samples = 100
    num_epochs = 10
    torch.manual_seed(42)
    train_data = torch.rand(num_train_samples, in_dim)
    val_data = torch.rand(num_val_samples, in_dim)
    train_targets = torch.randint(0, out_dim, (num_train_samples,))
    val_targets = torch.randint(0, out_dim, (num_val_samples,))
    model = MLP(in_dim, out_dim, n_layers, hidden_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_function = torch.nn.CrossEntropyLoss()
    out_dir = "./multirun/mlp-optuna-test"
    os.makedirs(out_dir, exist_ok=True)
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        train_outputs = model(train_data)
        loss = loss_function(train_outputs, train_targets)
        loss.backward()
        optimizer.step()
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_data)
            val_predictions = torch.argmax(val_outputs, dim=1)
            val_correct = (val_predictions == val_targets)
            acc = int(val_correct.sum()) / len(val_targets)
        trial.report(acc, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return acc
if __name__ == "__main__":
    sampler = TPESampler(seed=42)
    pruner = MedianPruner(n_startup_trials=3, n_warmup_steps=1)
    # study = optuna.create_study(direction="minimize", sampler=sampler, pruner=NopPruner())
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    print(f"Pruner: {study.pruner}")
    print(f"Sampler: {study.sampler}")
    study.optimize(objective, n_trials=20)
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    print("Study statistics:")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
    print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")
    print("  Number of pruned trials ---: ", len(pruned_trials))
    print("  Number of complete trials ---: ", len(complete_trials))
    print("Best trial:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print("  Params:")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")