Hi all,
Use case:
I am training a model on image data for a semantic segmentation task. I have a requirement to batch the data using ratios of class combinations. For example: a batch should have 75% of patches containing classes 0, 1, and 2, and 25% containing classes 0 and 1. Hence, I pre-batched my data.
I want to use this pre-batched data with DistributedDataParallel.
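For illustration, here is a simplified sketch of what I mean by pre-batching (get_classes and the 75/25 split are placeholders for my actual pipeline):

import random
from collections import defaultdict

def prebatch_by_ratio(samples, get_classes, batch_size):
    # Group patches by the exact set of classes they contain.
    groups = defaultdict(list)
    for s in samples:
        groups[frozenset(get_classes(s))].append(s)

    n_major = int(batch_size * 0.75)   # patches containing classes {0, 1, 2}
    n_minor = batch_size - n_major     # patches containing classes {0, 1}
    major = groups[frozenset({0, 1, 2})]
    minor = groups[frozenset({0, 1})]

    batches = []
    while len(major) >= n_major and len(minor) >= n_minor:
        batch = [major.pop() for _ in range(n_major)]
        batch += [minor.pop() for _ in range(n_minor)]
        random.shuffle(batch)
        batches.append(batch)
    return batches  # whatever does not fit the ratio gets dropped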
My question:
Let’s say n_gpus=4, batch_size=256, and n_batches=10. I understand that if I use the following code and give batch_size in the DataLoader as 256, it will feed the model on each GPU with 256/n_gpus = 64 samples per step. Since I want to ensure each model is fed a full pre-batch, let’s say I instead give batch_size in the DataLoader as 256*n_gpus = 1024; that would do what I want. However, as my n_batches is 10, the batches won’t divide equally across my GPUs: in the first round it will take 4 pre-batches, in the second another 4, but in the last round only 2 pre-batches are left. My batch size is also dynamic.
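A quick sanity check of that arithmetic:

n_batches, n_gpus = 10, 4
full_rounds, leftover = divmod(n_batches, n_gpus)
print(full_rounds, leftover)  # 2 full rounds of 4 pre-batches each, 2 left over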
How should I handle such a case?
- Option 1: Drop the remaining batches. Due to the ratio prerequisite in pre-batching, I already drop a lot of data, so I do not want to do this.
- Option 2: Make sure my batch count always satisfies n_batches % n_gpus == 0. This is hard (and not optimal), because I also have to satisfy the class-ratio prerequisite during pre-batching.
- Option 3: Train on 4 GPUs in the first two rounds and use only 2 GPUs in the final round. (Is this even possible?)

Or does PyTorch handle this automatically? (I’m new to PyTorch.)
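From what I can tell from the docs, DistributedSampler itself always makes the split even, either by repeating a few indices (drop_last=False, the default) or by dropping the tail (drop_last=True), but I am not sure either is right for pre-batched data. A minimal check with a toy dataset of 10 items standing in for my 10 pre-batches:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))  # stands in for 10 pre-batches

for rank in range(4):
    sampler = DistributedSampler(
        dataset, num_replicas=4, rank=rank, shuffle=False, drop_last=False
    )
    # With drop_last=False each rank gets ceil(10 / 4) = 3 indices, so
    # 12 - 10 = 2 indices end up duplicated across ranks.
    print(rank, list(sampler))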
The code:
import torch

# setup_device, cleanup, get_dataset, get_model, and train_model are helpers
# defined elsewhere in my script; batch_size and num_epochs are globals.

def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
    # Shard the dataset so each rank iterates over its own subset
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=num_gpus,
        rank=current_gpu_index,
        shuffle=False,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
    )
    return dataloader
def per_device_launch_fn(current_gpu_index, num_gpu):
    # Set up the process group for this rank
    setup_device(current_gpu_index, num_gpu)

    dataset = get_dataset()
    model = get_model()

    # Prepare the per-rank dataloader
    dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Instantiate the torch loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Put the model on the device and wrap it in DDP
    model = model.to(current_gpu_index)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[current_gpu_index], output_device=current_gpu_index
    )

    train_model(ddp_model, dataloader, num_epochs, optimizer, loss_fn)
    cleanup()
num_gpu = torch.cuda.device_count()  # one process per visible GPU

# start_processes passes each process its index as the first argument,
# followed by everything in args.
torch.multiprocessing.start_processes(
    per_device_launch_fn,
    args=(num_gpu,),
    nprocs=num_gpu,
    join=True,
    start_method="fork",
)
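One alternative I was also considering: since the data is already pre-batched, each Dataset item could be a whole pre-batch, and batch_size=None would disable the DataLoader's automatic batching so every rank pulls exactly one pre-batch per step. A sketch (PreBatchedDataset is a hypothetical wrapper, not something from my code above):

from torch.utils.data import Dataset

class PreBatchedDataset(Dataset):
    # Each item is an entire pre-batch (e.g. a tuple of image and mask tensors)
    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]

def prepare_prebatched_dataloader(batches, current_gpu_index, num_gpus):
    dataset = PreBatchedDataset(batches)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=num_gpus, rank=current_gpu_index, shuffle=False
    )
    # batch_size=None turns off automatic batching: each iteration yields
    # one pre-batch exactly as it was built.
    return torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=None)

This would still leave the n_batches % n_gpus != 0 problem, though.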