Distributed Data Parallel and CUDA graphs

I am executing the following code on four GPUs:

setup(rank, gpus)

dataset = RandomDataset(input_shape, 80*batch_size, rank)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
data_iter = iter(dataloader)



model = model(pretrained=True).to(rank)

optimizer = optim.SGD(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(s):
    print("[MAKING DDP Model]")
    model = DDP(model)
    print("[MODEL CREATED]")

    for i in range(11):
        optimizer.zero_grad(set_to_none=True)
        inputs, labels = next(data_iter)
        output = model(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

capture_input = torch.empty((batch_size, 3, input_shape, input_shape)).to(rank)
capture_target = torch.argmax(torch.from_numpy(np.eye(1000)[np.random.choice(1000, batch_size)]), axis=1).to(rank)


g = torch.cuda.CUDAGraph()

optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    capture_y_pred = model(capture_input)
    capture_loss = criterion(capture_y_pred, capture_target)
    capture_loss.backward()
optimizer.step()


print("RECORDED")

for i in range(20):
    inputs, label = next(data_iter)
    capture_input.copy_(inputs)
    capture_target.copy_(label)
    g.replay()
    optimizer.step()



print("DATASET DONE")

But I get the following error:
RuntimeError: CUDA error: operation would make the legacy stream depend on a capturing blocking stream
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
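
For reference, the flag mentioned in the last line only takes effect if it is set before CUDA is initialized, e.g. in the environment of the launch command or at the very top of the script:

import os
# Make kernel launches synchronous so the reported stack trace points
# at the call that actually failed.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"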

Is your model working on a single GPU? I also see that e.g. the torch.cuda.current_stream().wait_stream(s) call is missing before the torch.cuda.CUDAGraph() object is created.
Could you also add the missing pieces to your code to make it executable (e.g. the actual input shapes)?
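
For reference, the whole-network capture example in the PyTorch CUDA Graphs documentation warms up on a side stream and then rejoins the default stream before the graph object is created (static_input and static_target are placeholders for your capture tensors):

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    # A few eager warm-up iterations on the side stream.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = criterion(y_pred, static_target)
        loss.backward()
        optimizer.step()
# The call mentioned above: rejoin the ambient stream before capture.
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()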


Thank you for the reply.
I modified the code in the following way and now it's working:


import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.optim as optim
import numpy as np

# Disable NCCL's async error handling, which can conflict with CUDA graph capture.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"


def setup(rank, gpus):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=gpus)

def cleanup():
    dist.destroy_process_group()

class RandomDataset(Dataset):
    def __init__(self, size, length, rank):
        self.len = length
        self.data = torch.randn(length, 3, size, size).to(rank)
        self.output = torch.argmax(torch.from_numpy(np.eye(1000)[np.random.choice(1000, length)]), axis=1).to(rank)

    def __getitem__(self, index):
        return (self.data[index], self.output[index])

    def __len__(self):
        return self.len


def trainer_cv_cuda_graph(rank, input_shape=224, model_name="vgg19", model=torchvision.models.vgg19, batch_size=64, training_steps=300, gpus=4):

    setup(rank, gpus)


    dataset = RandomDataset(input_shape, 80*batch_size, rank)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    data_iter = iter(dataloader)



    model = model(pretrained=True).to(rank)

    optimizer = optim.SGD(model.parameters(), lr=0.0001)
    criterion = torch.nn.CrossEntropyLoss()

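    # Side stream, created explicitly on this rank's device; DDP construction
    # and the warm-up iterations below run inside its context.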
    s = torch.cuda.Stream(device=rank)
    s.wait_stream(torch.cuda.current_stream(device=rank))

    with torch.cuda.stream(s):
        print("[MAKING DDP Model]")
        model = DDP(model)
        print("[MODEL CREATED]")

        for i in range(11):
            optimizer.zero_grad(set_to_none=True)
            inputs, labels = next(data_iter)
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
    
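    # Static tensors: the graph is captured against these fixed buffers.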
    capture_input = torch.empty((batch_size, 3, input_shape, input_shape)).to(rank)
    capture_target = torch.argmax(torch.from_numpy(np.eye(1000)[np.random.choice(1000, batch_size)]), axis=1).to(rank)


    g = torch.cuda.CUDAGraph()

    optimizer.zero_grad(set_to_none=True)
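    # The fix: capture on the same side stream s used for the DDP warm-up above.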
    with torch.cuda.graph(g, stream=s):
        capture_y_pred = model(capture_input)
        capture_loss = criterion(capture_y_pred, capture_target)
        capture_loss.backward()
    optimizer.step()


    print("RECORDED")

    for i in range(100):
        try:
            inputs, label = next(data_iter)
        except StopIteration:
            # Restart the iterator once the dataloader is exhausted.
            data_iter = iter(dataloader)
            inputs, label = next(data_iter)
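        # Copy the new batch into the static capture tensors, replay the captured
        # forward/backward, then apply the optimizer step eagerly.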
        capture_input.copy_(inputs)
        capture_target.copy_(label)
        g.replay()
        optimizer.step()

    
    cleanup()

    print("[CLEANUP] Process {} Done".format(rank))


Good to hear you’ve solved the issue! What exactly did you change that fixed it?

I changed

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())

and

with torch.cuda.graph(g):

to

s = torch.cuda.Stream(device=rank)
s.wait_stream(torch.cuda.current_stream(device=rank))

and

with torch.cuda.graph(g, stream=s):
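
Presumably this helps because torch.cuda.graph captures on its own internal side stream when no stream argument is supplied; passing stream=s keeps the capture on the same side stream the DDP model was constructed and warmed up on, and creating the stream with device=rank ties it to this process's GPU.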