torch.jit._get_trace_graph: tensors found on two devices (cuda:0 and cpu)

Problem

I am using segment_anything and fvcore, but I get an error in torch.jit._get_trace_graph().

When I call fvcore.nn.FlopCountAnalysis(sam, input_batch).by_module_and_operator() (both libraries are developed by Meta), I get the error below.
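FlopCountAnalysis builds its operator graph by tracing the model, and the failing line graph, _ = _get_trace_graph(module, inputs) is that tracing step. A minimal sketch for calling the same private PyTorch API directly, assuming a recent PyTorch where torch.jit._get_trace_graph still returns a (graph, output) pair. TinyNet is a hypothetical stand-in for SAM. Running the same call on the real model may help show whether the failure comes from tracing itself rather than from fvcore.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for SAM; any nn.Module taking tensor inputs works."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


def trace_like_fvcore(module: nn.Module, inputs: tuple):
    # Private API (underscore prefix): its signature and return value may
    # change between releases. Here it returns the traced graph and the
    # module's output for the given inputs.
    graph, out = torch.jit._get_trace_graph(module, inputs)
    return graph, out


model = TinyNet().eval()
example = (torch.randn(2, 4),)
with torch.no_grad():
    graph, out = trace_like_fvcore(model, example)

# List the recorded operator kinds (exact names depend on the PyTorch version).
print([node.kind() for node in graph.nodes()])
```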

The call graph, _ = _get_trace_graph(module, inputs) raises:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

I move both the module and the inputs to CUDA before this call. I also get a segfault when I use Python debugging tools.
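One detail worth checking: nn.Module.to() moves a module's parameters and buffers in place, whereas Tensor.to() returns a new tensor, so the moved inputs must be reassigned. A minimal sketch, with hypothetical helper names (move_to, off_device_tensors), that moves both and then lists any parameters or buffers still on another device. It cannot see tensors created inside forward without an explicit device= argument (for example torch.arange(n)); those default to the CPU and could also cause a cuda:0/cpu mismatch.

```python
import torch
import torch.nn as nn


def move_to(module: nn.Module, inputs: tuple, device: torch.device):
    # Module.to() modifies the module in place (and returns it);
    # Tensor.to() returns a new tensor, so the inputs tuple is rebuilt.
    module.to(device)
    moved = tuple(x.to(device) if isinstance(x, torch.Tensor) else x for x in inputs)
    return module, moved


def _on_device(t: torch.Tensor, device: torch.device) -> bool:
    # torch.device("cuda") != torch.device("cuda:0"), so compare the device
    # type, and compare the index only when the target specifies one.
    if t.device.type != device.type:
        return False
    return device.index is None or t.device.index == device.index


def off_device_tensors(module: nn.Module, device: torch.device) -> list:
    """Names of parameters and buffers that are not on `device` (hypothetical helper)."""
    named = list(module.named_parameters()) + list(module.named_buffers())
    return [name for name, t in named if not _on_device(t, device)]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
model, example = move_to(model, (torch.randn(2, 4),), device)
print(off_device_tensors(model, device))  # an empty list means everything moved
```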

I have also modified some of SAM's code: its input is now just a list of images, and the prompts (point, mask, and bbox) are set to None.
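The tracer used here accepts only tensors (possibly nested in tuples or lists) as inputs, so None prompts cannot be passed through it directly. One common pattern is to fix the optional arguments inside a small wrapper module and trace only the image tensor. A hedged sketch of that pattern: PromptableStub and ImageOnly are hypothetical, and PromptableStub's forward is not SAM's actual signature; the real model's prompt handling would need to accept None in the same way.

```python
from typing import Optional

import torch
import torch.nn as nn


class PromptableStub(nn.Module):
    """Hypothetical model taking an image plus optional point/box/mask prompts."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(
        self,
        images: torch.Tensor,
        points: Optional[torch.Tensor] = None,
        boxes: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        feats = self.conv(images)
        if masks is not None:  # plain Python branch; resolved once at trace time
            feats = feats * masks
        return feats.mean(dim=(2, 3))


class ImageOnly(nn.Module):
    """Wraps a model so the traced callable receives only the image tensor."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Prompts are fixed to None here, so the tracer never sees them as inputs.
        return self.inner(images, points=None, boxes=None, masks=None)


wrapped = ImageOnly(PromptableStub()).eval()
images = torch.randn(1, 3, 16, 16)
graph, out = torch.jit._get_trace_graph(wrapped, (images,))
```

The same wrapper could then be handed to fvcore as FlopCountAnalysis(wrapped, (images,)), so that only tensors reach the tracing step.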