Hi, I need to check whether CUDA graph capture is underway inside a module that is compiled with torch.jit.script:
import torch
class Module(torch.nn.Module):
    def forward(self, x):
        if torch.cuda.graphs.is_current_stream_capturing():
            print("Capturing underway")
        return x
module = torch.jit.script(Module())
However, TorchScript does not support this function, and scripting fails with the following error:
Python builtin <built-in function _cuda_isCurrentStreamCapturing> is currently not supported in Torchscript:
  File "/shared/raul/mambaforge/envs/openmmtorch-test/lib/python3.10/site-packages/torch/cuda/graphs.py", line 25
    If a CUDA context does not exist on the current device, returns False without initializing the context.
    """
    return _cuda_isCurrentStreamCapturing()
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Could you help me come up with a workaround?
Thanks