I am trying to speed up a model using CUDA graphs. Parts of my model cannot be graphed, so I benchmarked both the ResNet from PyTorch (torchvision) and the ResNet from mmdet3d in three modes: ungraphed, partially graphed, and fully graphed.
import time
import torch
from mmdet3d.models import ResNet
# --- Benchmark configuration (shared by the sections further down) ---
iterations = 100
model_name = 'resnet152'
mmdet_model_size = 152

###################################################################
# Ungraphed baseline
#######################################################
# To benchmark the torchvision model instead, swap the constructor for:
#   torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False)
model = ResNet(mmdet_model_size)
model.train()
model = model.cuda()
criterion = torch.nn.MSELoss()  # torch.nn.CrossEntropyLoss() for the torchvision model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# Fixed-size tensors; the later CUDA-graph sections reuse them as static buffers.
static_input = torch.randn(2, 3, 224, 224).cuda()
static_target = torch.randn(2, 256, 56, 56).cuda()
# static_target = torch.randint(0, 1000, (2,)).long().cuda()

# Warm-up: the first iterations pay one-time costs (CUDA context setup,
# allocator growth, kernel selection) that would otherwise be charged to the
# baseline only and skew the comparison with the graphed runs.
for _ in range(3):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]  # only the first output is fed to the loss
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()

# CUDA work is queued asynchronously, so the host clock must only be read
# after the device has drained its queue - both before and after the loop.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"ungraphed: total time {iterations} iterations: ", elapsed)

# Release GPU memory before the next section. The optimizer state and the last
# activations/loss also live on the GPU, so drop them too; otherwise
# empty_cache() cannot return their blocks.
model.cpu()
del model, optimizer, out, loss
torch.cuda.synchronize()
torch.cuda.empty_cache()
###################################################################
# partial graph v1
#######################################################
# Only the model's forward/backward is graphed via make_graphed_callables;
# the loss, zero_grad and the optimizer step still run eagerly.
# To benchmark the torchvision model instead, swap the constructor for:
#   torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False)
model = ResNet(mmdet_model_size)
model.train()
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
for param_group in optimizer.param_groups:
    param_group['capturable'] = True
    print("Capturable:", param_group['capturable'])

model = torch.cuda.make_graphed_callables(model, (static_input,))
# NOTE(review): the graphed backward is captured for *all* outputs of the
# callable, while the loss below only uses output [0]. In eager mode autograd
# can skip the parts of the network that do not feed output [0]; the graphed
# version presumably cannot, which is a likely reason this variant ends up
# slower than the ungraphed run for multi-output models - confirm by profiling.

# warmup
for _ in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()

# CUDA work is queued asynchronously: drain the queue before starting the
# clock (so warm-up work is not counted) and again before stopping it (so all
# timed work is counted).
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iterations):
    optimizer.zero_grad(set_to_none=True)
    out = model(static_input)[0]
    loss = criterion(out, static_target)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"partial graph: total time {iterations} iterations: ", elapsed)

# Release GPU memory before the next section; the optimizer state and the
# last output/loss also hold GPU memory, so drop those references as well.
model.cpu()
del model, optimizer, out, loss
torch.cuda.synchronize()
torch.cuda.empty_cache()
###################################################################
# Fully Graphed
#######################################################
# The whole training step (forward, loss, backward, optimizer step) is
# captured into a single CUDA graph and replayed.
# To benchmark the torchvision model instead, swap the constructor for:
#   torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False)
model = ResNet(mmdet_model_size)
model.train()
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# The optimizer step is captured below, so it must be marked capturable.
for param_group in optimizer.param_groups:
    param_group['capturable'] = True
    print("Capturable:", param_group['capturable'])

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)[0]
        loss = criterion(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
# Trace the forward and backward pass within the graph.
# Gradients are set to None first so that backward() allocates them from the
# graph's private memory pool during capture.
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)[0]
    static_loss = criterion(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

# Real inputs for replaying the graph
real_inputs = [torch.zeros_like(static_input) for _ in range(iterations)]
real_targets = [torch.zeros_like(static_target) for _ in range(iterations)]

# Measure time for graphed version.
# replay() only enqueues the graph; without synchronizing before reading the
# clock the measurement would mostly reflect launch overhead rather than the
# GPU execution time.
torch.cuda.synchronize()
x_graphed = time.perf_counter()
for data, target in zip(real_inputs, real_targets):
    # Refill the static buffers in place; the graph reads from these addresses.
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()  # Replay the traced graph
torch.cuda.synchronize()
elapsed_graphed = time.perf_counter() - x_graphed
print(f"graphed:total time {iterations} iterations: {elapsed_graphed}")
Above is the code I use to test both models. With the plain torchvision ResNet, I see a speedup in both the partial and the full CUDA graph cases.
# Results for the torchvision ResNet:
ungraphed: total time 100 iterations: 7.005798816680908
Using cache found in /home/irdali.durrani/.cache/torch/hub/pytorch_vision_v0.10.0
partial graph: total time 100 iterations: 3.2496914863586426
Using cache found in /home/irdali.durrani/.cache/torch/hub/pytorch_vision_v0.10.0
graphed:total time 100 iterations: 1.4217324256896973
For the mmdet3d ResNet, the results are below. Why does the partially graphed model become slower than the original (ungraphed) model?
ungraphed: total time 100 iterations: 1.8167531490325928
Capturable: True
partial graph: total time 100 iterations: 3.173295259475708
Capturable: True
graphed:total time 100 iterations: 0.34721994400024414