Partial CUDA graphs are slower than the original model?

I am trying to speed up a model using CUDA graphs. Parts of the model cannot be graphed, so I want to graph only the parts that can. To compare approaches, I benchmarked the ResNet from torchvision against the ResNet from mmdet3d.

import time
import torch
from mmdet3d.models import ResNet

iterations = 100

model_name = 'resnet152'
mmdet_model_size = 152

# swap in the commented-out torch.hub call to benchmark the torchvision model instead
model = ResNet(mmdet_model_size)  # torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False)
model.train()
model = model.cuda()

criterion = torch.nn.MSELoss()  # torch.nn.CrossEntropyLoss() for the torchvision model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

static_input = torch.randn(2, 3, 224, 224).cuda()
static_target = torch.randn(2, 256, 56, 56).cuda()  # matches the shape of the first mmdet3d feature map
# static_target = torch.randint(0, 1000, (2,)).long().cuda()  # class targets for the torchvision model

x = time.time()
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()

print(f"ungraphed: total time {iterations} iterations: ", time.time() - x)
model.cpu()
del model
torch.cuda.synchronize()
torch.cuda.empty_cache()

###################################################################
# partial graph v1
###################################################################

model = ResNet(mmdet_model_size)  # torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False)
model.train()
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

for param_group in optimizer.param_groups:
    param_group['capturable'] = True
    print("Capturable:", param_group['capturable'])

model = torch.cuda.make_graphed_callables(model, (static_input,))
# warmup
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()

x = time.time()
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
print(f"partial graph: total time {iterations} iterations: ", time.time() - x)
model.cpu()
del model
torch.cuda.synchronize()
torch.cuda.empty_cache()

###################################################################
# Fully Graphed
###################################################################

model = ResNet(mmdet_model_size)  # torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False)
model.train()
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

for param_group in optimizer.param_groups:
    param_group['capturable'] = True
    print("Capturable:", param_group['capturable'])

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)[0]
        loss = criterion(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()

# Capture the forward pass, backward pass, and optimizer step in the graph
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)[0]
    static_loss = criterion(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

# Placeholder inputs that will be copied into the static tensors before each replay
real_inputs = [torch.zeros_like(static_input) for _ in range(iterations)]
real_targets = [torch.zeros_like(static_target) for _ in range(iterations)]

# Measure time for graphed version
x_graphed = time.time()
for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()  # Replay the traced graph

print(f"graphed:total time {iterations} iterations: {time.time() - x_graphed}")

Above is the code I use to test both models. With the plain torchvision ResNet, I see a speedup in both the partial and the fully graphed case.

Results for the torchvision resnet152:
ungraphed: total time 100 iterations:  7.005798816680908
Using cache found in /home/irdali.durrani/.cache/torch/hub/pytorch_vision_v0.10.0
partial graph: total time 100 iterations:  3.2496914863586426
Using cache found in /home/irdali.durrani/.cache/torch/hub/pytorch_vision_v0.10.0
graphed:total time 100 iterations: 1.4217324256896973

With the mmdet3d ResNet, the results are below. Why does the partial graph become slower than the original model?

ungraphed: total time 100 iterations:  1.8167531490325928
Capturable: True
partial graph: total time 100 iterations:  3.173295259475708
Capturable: True
graphed:total time 100 iterations: 0.34721994400024414

CUDA operations are executed asynchronously, so you would need to synchronize the code before starting and stopping the host timers for proper profiling.
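
For example (a minimal sketch reusing the variable names from your script), CUDA events handle the measurement on the device side; the final synchronize is still needed before reading the result:

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()  # make sure previously queued work has finished
start.record()
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
end.record()
torch.cuda.synchronize()  # wait until the 'end' event has actually been reached
print(f"elapsed: {start.elapsed_time(end) / 1e3:.3f} s")  # elapsed_time() returns milliseconds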

Even after calling torch.cuda.synchronize() before starting and stopping the timers, I get the same results. Below are the updated code and output.

import warnings

warnings.simplefilter(action='ignore')

import time
import torch
from mmdet3d.models import ResNet


iterations = 200

mmdet_model_size = 152

model = ResNet(mmdet_model_size)
model.train()
model = model.cuda()

criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

static_input = torch.randn(2, 3, 224, 224).cuda()
static_target = torch.randn(2, 256, 56, 56).cuda()  # for mmdet3d model

torch.cuda.synchronize()
x = time.time()
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
print(f"ungraphed: total time {iterations} iterations: ", time.time() - x)
model.cpu()
del model
torch.cuda.synchronize()
torch.cuda.empty_cache()

###################################################################
# partial graph v1
###################################################################

model = ResNet(mmdet_model_size)
model.train()
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

for param_group in optimizer.param_groups:
    param_group['capturable'] = True

model = torch.cuda.make_graphed_callables(model, (static_input,))
# warmup
for i in range(11):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()

torch.cuda.synchronize()
x = time.time()
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
print(f"partial graph: total time {iterations} iterations: ", time.time() - x)
model.cpu()
del model
torch.cuda.synchronize()
torch.cuda.empty_cache()

###################################################################
# Fully Graphed
###################################################################

model = ResNet(mmdet_model_size)
model.train()
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

for param_group in optimizer.param_groups:
    param_group['capturable'] = True

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)[0]
        loss = criterion(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()

# Capture the forward pass, backward pass, and optimizer step in the graph
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)[0]
    static_loss = criterion(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

# Placeholder inputs that will be copied into the static tensors before each replay
real_inputs = [torch.zeros_like(static_input) for _ in range(iterations)]
real_targets = [torch.zeros_like(static_target) for _ in range(iterations)]

# Measure time for graphed version
torch.cuda.synchronize()
x_graphed = time.time()
for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()  # Replay the traced graph
torch.cuda.synchronize()
print(f"graphed:total time {iterations} iterations: {time.time() - x_graphed}")
Output:

ungraphed: total time 200 iterations:  3.1128504276275635
partial graph: total time 200 iterations:  6.3470940589904785
graphed:total time 200 iterations: 0.7746648788452148

torch.cuda.make_graphed_callables() should be creating the same graph as the one I am building manually. Is there a way to visualize both graphs so I can compare them? I would also be grateful for any other suggestions on how to improve the performance of torch.cuda.make_graphed_callables(), since it is consistently worse than capturing a graph manually.
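
For the manual graph, one option might be CUDAGraph's debug dump (a sketch; this assumes a PyTorch build where enable_debug_mode() and debug_dump() are available, and the graphs captured inside make_graphed_callables() are not exposed, so they cannot be dumped this way):

g = torch.cuda.CUDAGraph()
g.enable_debug_mode()  # must be enabled before capture
with torch.cuda.graph(g):
    static_y_pred = model(static_input)[0]
    static_loss = criterion(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()
g.debug_dump("manual_graph.dot")  # Graphviz file; render with: dot -Tsvg manual_graph.dot -o manual_graph.svg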

I cannot reproduce the slowdown using torchvision.models.resnet101 and see:

ungraphed: total time 200 iterations:  5.168176889419556
partial graph: total time 200 iterations:  2.0263397693634033
graphed:total time 200 iterations: 1.7669649124145508

after adapting the target shapes.
Trying to install mmdet3d fails with:

ModuleNotFoundError: No module named 'setuptools.extern.six'

so I’m unsure which version is compatible with a recent PyTorch build.

You could profile your workload with e.g. Nsight Systems to narrow down where the slowdown comes from.
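
For example (a sketch reusing the loop from the script above), NVTX ranges make each benchmark region easy to locate in the Nsight Systems timeline:

# run under Nsight Systems with: nsys profile -o report python benchmark.py
torch.cuda.nvtx.range_push("partial_graph")
for i in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
torch.cuda.nvtx.range_pop()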

With the torchvision ResNet I also saw similar results (see the original post), but with the mmdet3d ResNet it is slower.

Below are the steps to set up the environment to run the above code:


conda create -n pread python=3.10
conda activate pread

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

pip install -U openmim
mim install mmcv==2.1.0

git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -v -e .
# test if installed properly
mim download mmdet --config rtmdet_tiny_8xb32-300e_coco --dest .


# install mmdet3d
cd ..
git clone https://github.com/open-mmlab/mmdetection3d.git -b dev-1.x
cd mmdetection3d
pip install -v -e .
mim download mmdet3d --config pointpillars_hv_secfpn_8xb6-160e_kitti-3d-car --dest .
# should run successfully, if running locally should show a 3d point cloud window
python demo/pcd_demo.py demo/data/kitti/000008.bin pointpillars_hv_secfpn_8xb6-160e_kitti-3d-car.py hv_pointpillars_secfpn_6x8_160e_kitti-3d-car_20220331_134606-d42d15ed.pth --show

This is not specific to mmdet3d: I am trying to write code that graphs parts of large models, and torch.cuda.make_graphed_callables() turns out to be slower than the original model for quite a few models, not just the mmdet3d ResNet.
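
What I am ultimately after looks roughly like the sketch below: graph only a graph-safe submodule and run the rest eagerly (backbone and head are hypothetical attribute names here, not mmdet3d API; make_graphed_callables also accepts a tuple of callables with a matching tuple of sample-argument tuples):

# hypothetical model with a graphable backbone and a non-graphable head
model.backbone = torch.cuda.make_graphed_callables(model.backbone, (static_input,))
out = model.head(model.backbone(static_input))  # graphed backbone, eager head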