I am loving the new CUDAGraph functionality in PyTorch. I am trying to graph a transformer-based model, and if I fix the shapes to always use the maximum sequence length, then everything works great.
However, my training data comes in a few different sequence lengths. Let’s say for example’s sake I have 4 different sequence lengths: 16, 24, 32, and 48.
Would it be possible for me to capture multiple CUDA graphs for the same module, so that I have one replayable graph per shape/sequence length I expect to give it?
I’m imagining something like the following:
```python
module = MyModule()

# Create one graphed callable per expected input shape.
expected_shapes = [[16, 16], [16, 24], [16, 32], [16, 48]]
rand_inputs = [torch.randn(shape, device="cuda") for shape in expected_shapes]
graphs = [
    torch.cuda.make_graphed_callables(module, (inp,))
    for inp in rand_inputs
]

# Replay the correct graph during training.
for sample in dataloader:
    # Find the right graph index based on the sample's shape.
    graph_idx = get_graph_idx(sample.shape)
    output = graphs[graph_idx](sample)
```
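For reference, the `get_graph_idx` helper I have in mind is just a lookup from a sample's sequence length to the matching bucket. This is only a sketch; the helper name, the `(batch, seq_len)` shape assumption, and the round-up-to-next-bucket behavior are all mine:

```python
import bisect

# Supported sequence lengths, sorted ascending (matches the shapes above).
SEQ_LENGTHS = [16, 24, 32, 48]

def get_graph_idx(shape):
    """Map a sample's shape to the index of the graph captured for its bucket.

    Assumes shape is (batch, seq_len). A seq_len between buckets maps to the
    next larger bucket, on the assumption the input would be padded up to it.
    """
    seq_len = shape[-1]
    idx = bisect.bisect_left(SEQ_LENGTHS, seq_len)
    if idx == len(SEQ_LENGTHS):
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket")
    return idx
```

So `get_graph_idx((16, 24))` gives `1`, and `get_graph_idx((16, 30))` gives `2` (that sample would be padded to length 32 before replay).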
So far this doesn’t seem to work: the second call to `make_graphed_callables` complains about mismatched shapes. Is there anything I can do here?