Question about CUDA graphs

I have a question about CUDA graphs.

This is the typical usage of a CUDA graph:

import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
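For reference, a fuller version of that pattern, loosely following the capture/replay recipe in the PyTorch CUDA Graphs notes, might look like the sketch below. The model (`nn.Linear`), loss (`nn.MSELoss`), tensor shapes, and the helper name `capture_and_replay` are placeholders chosen for illustration, and the sketch only does GPU work when CUDA is available. It shows two details the snippet above leaves out: a warm-up pass on a side stream before capture, and the fact that replay re-launches the recorded kernels on the same static tensors, so new data has to be copied into `static_input`/`static_target` first.

```python
import torch
import torch.nn as nn

def capture_and_replay(batches):
    """Capture one forward + loss pass into a CUDA graph, then replay it per batch.

    The model, loss and shapes are hypothetical placeholders. Each element of
    `batches` is an (input, target) pair of CUDA tensors of shape (8, 16) / (8, 4).
    """
    device = torch.device("cuda")
    model = nn.Linear(16, 4).to(device)
    loss_fn = nn.MSELoss()

    # Static tensors: every replay reads and writes these exact memory addresses.
    static_input = torch.zeros(8, 16, device=device)
    static_target = torch.zeros(8, 4, device=device)

    # Warm up on a side stream before capture, as the PyTorch docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            loss_fn(model(static_input), static_target)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)

    losses = []
    for data, target in batches:
        # Replay does not re-run this Python code; it re-launches the recorded
        # kernels, so new data must be copied into the static tensors first.
        static_input.copy_(data)
        static_target.copy_(target)
        g.replay()
        losses.append(static_loss.item())
    return losses

if torch.cuda.is_available():
    dev = torch.device("cuda")
    batches = [(torch.randn(8, 16, device=dev), torch.randn(8, 4, device=dev)) for _ in range(2)]
    print(capture_and_replay(batches))
```

Note that anything the Python code does on the CPU while inside the `with torch.cuda.graph(g):` block runs once, during capture, and is not part of `g.replay()`.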

My question is: when I build the model, what if it contains some CPU operations?
For example:

x = layer_0(x)
x = layer_1(x)
x = layer_2(x)
x = layer_3(x)

Suppose I specifically want to run layer_1() or layer_2() on the CPU (put differently, I want to transfer the tensor from the GPU to the CPU, process it there, and transfer it back).
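For concreteness, in plain eager mode (no graph) that scenario might look like the module below. The class name `MixedDeviceModel`, the layer types, and the sizes are hypothetical; the point is the explicit GPU-to-CPU and CPU-to-GPU transfers around `layer_1` and `layer_2`. It falls back to the CPU for every layer when no GPU is present, so it runs anywhere.

```python
import torch
import torch.nn as nn

class MixedDeviceModel(nn.Module):
    """Hypothetical model: layer_0 and layer_3 on `device`, layer_1 and layer_2 on the CPU."""

    def __init__(self, dim: int, device: torch.device):
        super().__init__()
        self.device = device
        self.layer_0 = nn.Linear(dim, dim).to(device)
        self.layer_1 = nn.Linear(dim, dim)  # stays on the CPU
        self.layer_2 = nn.ReLU()            # stays on the CPU
        self.layer_3 = nn.Linear(dim, dim).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_0(x)
        # Device-to-host copy: the host blocks until layer_0's kernels have finished.
        x = x.to("cpu")
        x = self.layer_1(x)
        x = self.layer_2(x)
        # Host-to-device copy for the remaining GPU layer.
        x = x.to(self.device)
        return self.layer_3(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MixedDeviceModel(32, device)
out = model(torch.randn(4, 32, device=device))
print(out.shape, out.device)
```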

Can a CUDA graph capture this behavior? I guess the dependencies before and after the CPU part would need to be recorded during capture somehow? If there is only a single compute stream, would that be easy to do?
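For reference, as I understand the PyTorch CUDA Graphs notes, a capture records only the CUDA work launched on the capturing stream; CPU code runs once during capture and is not replayed, and synchronizing calls such as `.cpu()` or `.item()` are not allowed inside a capture. The same notes describe "partial-network capture" with `torch.cuda.make_graphed_callables`: graph only the capture-safe parts and run the rest eagerly. A minimal sketch of that idea for the example above, assuming `layer_0` and `layer_3` are capture-safe GPU modules and `layer_1`/`layer_2` run eagerly on the CPU (layer types, sizes, and the `forward` helper are hypothetical). With a single stream, ordering between the graph replays and the eager copies simply follows stream order, and the blocking `.cpu()` copy waits for `layer_0`'s replay to finish.

```python
import torch
import torch.nn as nn

dim, batch = 32, 4  # hypothetical sizes

# layer_1 and layer_2 run eagerly on the CPU, outside any graph.
layer_1 = nn.Linear(dim, dim)
layer_2 = nn.ReLU()

if torch.cuda.is_available():
    dev = torch.device("cuda")
    layer_0 = nn.Linear(dim, dim).to(dev)
    layer_3 = nn.Linear(dim, dim).to(dev)
    # Capture only the GPU segments. Each sample input must match the real
    # input's shape, dtype, device and requires_grad state: layer_3's real
    # input requires grad because it comes out of the CPU layer_1.
    layer_0, layer_3 = torch.cuda.make_graphed_callables(
        (layer_0, layer_3),
        (
            (torch.randn(batch, dim, device=dev),),
            (torch.randn(batch, dim, device=dev, requires_grad=True),),
        ),
    )
else:
    # Fallback so the sketch still runs without a GPU (no graphs involved).
    dev = torch.device("cpu")
    layer_0 = nn.Linear(dim, dim)
    layer_3 = nn.Linear(dim, dim)

def forward(x: torch.Tensor) -> torch.Tensor:
    x = layer_0(x)           # graph replay when CUDA is available
    x = x.cpu()              # eager, blocking device-to-host copy between the graphs
    x = layer_2(layer_1(x))  # eager CPU work, never captured
    x = x.to(dev)            # eager host-to-device copy
    return layer_3(x)        # graph replay when CUDA is available

out = forward(torch.randn(batch, dim, device=dev))
print(out.shape)
```

The outputs of graphed callables appear to live in the graph's static memory, so they may be overwritten by the next replay; clone them if they need to survive across calls.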


gentle ping @mcarilli