Optuna PyTorch DDP DataParallelism mp.spawn trial.report not working

I am new to optuna and was trying a simple ddp example with pytorch where I want to parallelize or use ddp for data parallelism with 2 GPUs. I am trying to use mp.spawn for data parallelisn within each trial and not actually trying to parallelize multiple trials. I have attached my code here. The challenge I am facing is as I pass the “trial” object to the second function “callback”, the trial is not reporting properly and thus not pruning. Can you please suggest to me how I can resolve this?

For your reference, I also included another script without ddp as my reference where pruning is happening and I am trying to replicate the same result (same trials being pruned) with ddp.

ddp_script.py
import os

import logging

import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner

from optuna.integration import PyTorchLightningPruningCallback

import lightning.pytorch as pl

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

optuna.logging.set_verbosity(optuna.logging.DEBUG)

class MLP(torch.nn.Module):
def init(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
super().init()

    layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
    for _ in range(n_layers):
        layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
    layers.append(nn.Linear(hidden_dim,out_dim))

    self.model = nn.Sequential(*layers)

def forward(self,x):
    return self.model(x)

def setup(rank, world_size):
os.environ[‘MASTER_ADDR’] = ‘localhost’
os.environ[‘MASTER_PORT’] = ‘12355’
dist.init_process_group(“nccl”, rank=rank, world_size=world_size)

def cleanup():
dist.destroy_process_group()

class PruningCallback:
def init(self,trial,monitor=“accuracy”):
self.trial = trial
self.monitor = monitor

def on_epoch_end(self,epoch,metrics):
    value = metrics.get(self.monitor)
    if value is None:
        return
    
    self.trial.report(value,step=epoch)
    if self.trial.should_prune():
        raise optuna.TrialPruned()

def objective(trial):

def objective(rank, world_size, params, callback, return_dict):

setup(rank, world_size)
torch.manual_seed(42)
device = torch.device(f"cuda:{rank}")

in_dim = 10
out_dim = 3
num_train_samples=500
num_val_samples=100
num_epochs = 10
batch_size = 64

train_data = torch.rand(num_train_samples,in_dim).to(device)
val_data = torch.rand(num_val_samples,in_dim).to(device)
train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)

# train_dataset = TensorDataset(train_data,train_targets)
# val_dataset = TensorDataset(val_data,val_targets)
# # print(len(train_dataset),len(val_dataset))

# train_sampler = DistributedSampler(train_dataset,num_replicas=world_size,rank=rank,shuffle=True)
# val_sampler = DistributedSampler(val_dataset,num_replicas=world_size,rank=rank,shuffle=False)
# # print(len(train_sampler),len(val_sampler))

# train_loader = DataLoader(train_dataset,sampler=train_sampler,batch_size=batch_size)
# val_loader = DataLoader(val_dataset,sampler=val_sampler,batch_size=batch_size)
# # print(len(train_loader),len(val_loader))

model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
loss_function = torch.nn.CrossEntropyLoss()

out_dir = "./multirun/mlp-optuna-test"
os.makedirs(out_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    train_outputs = model(train_data)
    loss = loss_function(train_outputs,train_targets)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_outputs = model(val_data)
        val_predictions = torch.argmax(val_outputs,dim=1)
        val_correct = (val_predictions == val_targets)
        acc = int(val_correct.sum())/len(val_targets)

    # model.train()
    # # train_sampler.set_epoch(epoch)

    # for inputs, targets in train_loader:
    #     inputs, targets = inputs.to(device), targets.to(device)
    #     optimizer.zero_grad()
    #     outputs = model(inputs)
    #     loss = loss_function(outputs,targets)
    #     loss.backward()
    #     optimizer.step()

    # model.eval()
    # correct = 0
    # total = 0
    # with torch.no_grad():
    #     for inputs, targets in val_loader:
    #         inputs, targets = inputs.to(device), targets.to(device)
    #         outputs=model(inputs)
    #         preds = torch.argmax(outputs,dim=1)
    #         correct += (preds == targets).sum().item()
    #         total += targets.size(0)

    # acc = correct/total
    
    acc_tensor = torch.tensor(acc,device=device)
    dist.all_reduce(acc_tensor)
    acc_avg = acc_tensor.item()/world_size
    # print(rank,acc,acc_tensor,acc_avg)

    if rank==0:
        callback.on_epoch_end(epoch,{"accuracy":acc_avg})

    # if rank == 0:
    #     trial.report(acc_avg, epoch)

    #     if trial.should_prune():
    #         cleanup()
    #         raise optuna.exceptions.TrialPruned()

if rank==0:
    return_dict["result"] = acc_avg

cleanup()

def ddp_objective(trial):
params = {
“n_layers”: trial.suggest_int(“n_layers”, 1, 5),
“hidden_dim”: trial.suggest_int(“hidden_dim”, 32, 64),
}

world_size = 2  # Number of GPUs
manager = mp.Manager()
return_dict = manager.dict()

callback = PruningCallback(trial,monitor="accuracy")
mp.spawn(
    objective,
    args=(world_size, params, callback, return_dict),
    nprocs=world_size,
    join=True,
)

return return_dict["result"]

if name == “main”:

sampler = TPESampler(seed=42)
pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
# study = optuna.create_study(direction="minimize", sampler=sampler, pruner=NopPruner())
study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
print(f"Pruner:{study.pruner}")
print(f"Sampler:{study.sampler}")

study.optimize(ddp_objective,n_trials=20)

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics:")
print(f"  Number of finished trials: {len(study.trials)}")
print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")

print("  Number of pruned trials ---: ", len(pruned_trials))
print("  Number of complete trials ---: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print(f"  Value: {trial.value}")
print(f"  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

where as script.py is for without ddp
import hydra
from omegaconf import DictConfig
import os
import logging
import torch
import torch.nn as nn

import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner

class MLP(torch.nn.Module):
def init(self,in_dim,out_dim,n_layers,hidden_dim):
super().init()

    layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
    print(f'n_layers:{n_layers}')
    for _ in range(n_layers):
        layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
    layers.append(nn.Linear(hidden_dim,out_dim))

    self.model = nn.Sequential(*layers)

def forward(self,x):
    return self.model(x)

def objective(trial):

n_layers = trial.suggest_int("n_layers",1,5)
hidden_dim = trial.suggest_int("hidden_dim",32,64)

in_dim = 10
out_dim = 3
num_train_samples=500
num_val_samples=100
num_epochs = 10
torch.manual_seed(42)

train_data = torch.rand(num_train_samples,in_dim)
val_data = torch.rand(num_val_samples,in_dim)
train_targets = torch.randint(0,out_dim,(num_train_samples,))
val_targets = torch.randint(0,out_dim,(num_val_samples,))

model = MLP(in_dim,out_dim,n_layers,hidden_dim)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
loss_function = torch.nn.CrossEntropyLoss()

out_dir = "./multirun/mlp-optuna-test"
os.makedirs(out_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    train_outputs = model(train_data)
    loss = loss_function(train_outputs,train_targets)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_outputs = model(val_data)
        val_predictions = torch.argmax(val_outputs,dim=1)
        val_correct = (val_predictions == val_targets)
        acc = int(val_correct.sum())/len(val_targets)

    trial.report(acc, epoch)

    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()

return acc

if name == “main”:

sampler = TPESampler(seed=42)
pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
# study = optuna.create_study(direction="minimize", sampler=sampler, pruner=NopPruner())
study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
print(f"Pruner:{study.pruner}")
print(f"Sampler:{study.sampler}")

study.optimize(objective,n_trials=20)

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics:")
print(f"  Number of finished trials: {len(study.trials)}")
print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")

print("  Number of pruned trials ---: ", len(pruned_trials))
print("  Number of complete trials ---: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print(f"  Value: {trial.value}")
print(f"  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")