I am trying to train a distributed model based on whether an instance is captured in some prediction. Basically, if the object of interest is detected in the scan, continue to the loss and gradient descent step; if not, skip to the next batch. However, the difficulty is that the model requires find_unused_parameters=True with torch.nn.parallel.DistributedDataParallel to accommodate a flow-control scheme, and when find_unused_parameters is True training stalls at the backward pass after a batch is skipped. How could this be addressed? I have seen the static_graph argument in torch==1.11+, but it is not clear whether that is the best solution or even the right approach.
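If static_graph is the way to go, my understanding is that it would simply be passed where the model is wrapped, mirroring the DistributedDataParallel call in the demo below (assuming torch>=1.11, where static_graph is a DDP constructor argument; I am not sure whether find_unused_parameters still needs to be set alongside it, or whether this actually resolves the skipped-backward stall):

net = torch.nn.parallel.DistributedDataParallel(
    net,
    device_ids=[rank],
    broadcast_buffers=False,
    find_unused_parameters=True,
    static_graph=True,  # asserts the set of used/unused parameters stays the same every iteration
)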
I've made a simple model below to demonstrate what I am trying to do; here I force the skip rather than checking a value computed during training.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler
import os
import numpy as np
class SampleNet(nn.Module):
    def __init__(self):
        super(SampleNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        # self.unused_conv = nn.Conv2d(in_channels=2, out_channels=12, kernel_size=3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.conv(x)
        x = self.avg_pool(x).squeeze(-1).squeeze(-1)
        return x
class SampleDataset(data.Dataset):
    def __init__(self):
        super(SampleDataset, self).__init__()
        self.rand_samples = np.random.random((4, 1, 32, 32))
        self.gt = np.array([1, 1, 1, 1])

    def __len__(self):
        return 4

    def __getitem__(self, sample_idx):
        sample_torch = torch.from_numpy(self.rand_samples[sample_idx]).to(torch.float32)
        sample_gt_torch = torch.from_numpy(np.array([1.0])).to(torch.float32)
        return sample_torch, sample_gt_torch
def main(gpu):
    device = f'cuda:{gpu}'
    world_size = 2
    batch_size = 1
    rank = gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank)

    net = SampleNet()
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank],
                                                    find_unused_parameters=False,
                                                    broadcast_buffers=False,
                                                    )
    scaler = GradScaler()
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
    epochs = 1

    train_dataset = SampleDataset()
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, num_workers=1,
                              pin_memory=True, drop_last=True, sampler=train_sampler,
                              persistent_workers=True, prefetch_factor=1)
    print('starting training')
    for epoch in range(epochs):
        net.train()
        train_sampler.set_epoch(epoch)
        for sample_index, (sample, target) in enumerate(train_loader):
            optimizer.zero_grad()
            sample = sample.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            with autocast():
                pred_sample = net(sample)

            # force a skip on rank 1 for the first batch (stands in for "object not detected")
            skip_sample = False
            if gpu == 1:
                if sample_index == 0:
                    skip_sample = True

            if not skip_sample:
                loss = F.binary_cross_entropy_with_logits(input=pred_sample, target=target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                print('backward completed: ', gpu)

            torch.distributed.barrier()
            if gpu == 0:
                print('Here "a": ', sample_index)
            if gpu == 1:
                print('Here "b": ', sample_index)
            torch.distributed.barrier()

    if gpu == 0:
        print('GPU 0 finished.')
    if gpu == 1:
        print('GPU 1 finished.')
    torch.distributed.barrier()
    if gpu == 0:
        print('GPU 0 after barrier.')
    if gpu == 1:
        print('GPU 1 after barrier.')
if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    mp.set_start_method('spawn')
    tot_workers = 2
    mp.spawn(main, nprocs=tot_workers)
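One workaround I have been considering, though I am not sure it is the idiomatic fix, is to never actually skip backward(): when nothing is detected, backpropagate a zero-valued loss built from the prediction so that every rank still runs backward() for the same parameters and participates in DDP's gradient all-reduce. Roughly, the inner loop above would become something like:

with autocast():
    pred_sample = net(sample)

    skip_sample = (gpu == 1 and sample_index == 0)  # same forced skip as in the demo

    if skip_sample:
        # zero-valued loss that still touches pred_sample, so this rank
        # produces (zero) gradients for the same parameters as the other ranks
        loss = (pred_sample * 0.0).sum()
    else:
        loss = F.binary_cross_entropy_with_logits(input=pred_sample, target=target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

The idea is that the zero loss yields zero gradients for the same parameter set on every rank, so the all-reduce should never desynchronize, but I am not sure whether that is preferable to static_graph.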