Multiple CUDAGraphs for single model with different shape inputs

I am loving the new CUDAGraph functionality in PyTorch. I am trying to graph a transformer-based model, and if I fix the shapes to always use the maximum sequence length, then everything works great.

However, my training data comes in a few different sequence lengths. Let’s say for example’s sake I have 4 different sequence lengths: 16, 24, 32, and 48.

Would it be possible for me to capture multiple CUDAGraphs for the same module, so that I have one replayable graph per shape/sequence length I expect to feed it?

I’m imagining something like the following:

module = MyModule()

# create each graph based on shapes
expected_shapes = [[16, 16], [16, 24], [16, 32], [16, 48]]
rand_inputs = [torch.randn(shape) for shape in expected_shapes]
graphs = [torch.cuda.make_graphed_callables(module, (inp,)) for inp in rand_inputs]

# replay correct graph during training
for sample in dataloader:
    # find correct graph idx based on shape
    graph_idx = get_graph_idx(sample.shape)
    output = graphs[graph_idx](sample)

So far this doesn’t seem to work: the second call to make_graphed_callables complains about mismatched shapes. Is there anything I can do here?

A minimal working example of what I’d hope would work, but doesn’t, is below. Instead of varying the sequence length I’m varying the batch size here, but the result is the same:

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

        self.linear = torch.nn.Linear(5, 1)

    def forward(self, x):
        return self.linear(x)

if __name__ == "__main__":
    module = MyModule().cuda()

    expected_shapes = [[2, 5], [4, 5], [6, 5]]
    rand_inputs = [torch.randn(shape, device="cuda") for shape in expected_shapes]

    # graph the module once per expected input shape
    graphed_callables = []
    for inp in rand_inputs:
        graphed_module = torch.cuda.make_graphed_callables(module, (inp,))
        graphed_callables.append(graphed_module)

    # replay each graphed callable on fresh inputs of the matching shape
    new_inputs = [torch.randn(shape, device="cuda") for shape in expected_shapes]
    for graphed_callable, inp in zip(graphed_callables, new_inputs):
        output = graphed_callable(inp)

When I run this, the second call to torch.cuda.make_graphed_callables raises a RuntimeError:

RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0
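Edit: for completeness, here is a fallback I’m considering while I wait for an answer. It skips make_graphed_callables entirely and captures one graph per shape manually with torch.cuda.CUDAGraph and the torch.cuda.graph context manager, each with its own static input/output buffers. This sketch is inference-only (no autograd), uses a plain Linear stand-in for my module, and is guarded so it only runs when a GPU is present:

```python
import torch

if torch.cuda.is_available():
    device = "cuda"
    module = torch.nn.Linear(5, 1).to(device)

    expected_shapes = [(2, 5), (4, 5), (6, 5)]
    static_inputs, static_outputs, graphs = {}, {}, {}

    # Warm up on a side stream, as recommended before graph capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for shape in expected_shapes:
            module(torch.randn(shape, device=device))
    torch.cuda.current_stream().wait_stream(s)

    # Capture one graph per shape, each with its own static buffers.
    for shape in expected_shapes:
        static_inputs[shape] = torch.randn(shape, device=device)
        g = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(g):
            static_outputs[shape] = module(static_inputs[shape])
        graphs[shape] = g

    # Replay: copy fresh data into the matching static input buffer,
    # replay that shape's graph, then read the static output.
    x = torch.randn(4, 5, device=device)
    key = tuple(x.shape)
    static_inputs[key].copy_(x)
    graphs[key].replay()
    result = static_outputs[key].clone()
```

This avoids the shape-mismatch error because each shape gets a completely separate capture, but it means I have to manage the static buffers and graph lookup myself, which is what I was hoping make_graphed_callables would do for me.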