torch.fx.symbolic_trace with multiple GPUs

Hi there,

I couldn’t find any documentation on this topic.

I’d like to run torch.fx.symbolic_trace on a large model that exceeds the memory of a single GPU, e.g. Llama 405B. I have multiple GPUs available. What’s the best way to go about this?
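
For concreteness, here is a minimal sketch of the call in question, run on a tiny stand-in module rather than the real model. The class name TinyBlock and its sizes are hypothetical, and the sketch assumes PyTorch 2.x, where torch.device can be used as a context manager. It builds the module on the meta device so that no weight memory is allocated; symbolic tracing records operations through proxies instead of running real tensors, so this may be one way around the single-GPU limit, though whether it carries over to a full 405B model is exactly the open question.

```python
import torch
import torch.fx
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for one block of a much larger model."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# Parameters created under the meta device have shapes and dtypes
# but no backing storage, so nothing is allocated on CPU or GPU.
with torch.device("meta"):
    model = TinyBlock()

# symbolic_trace feeds Proxy objects through forward() and records
# each operation as a graph node; no real computation happens here.
gm = torch.fx.symbolic_trace(model)

print(gm.graph)
```

The result is only the graph structure. Loading real weights and placing them across several GPUs would be a separate step that this sketch does not cover.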

Thank you!