Hi, I am trying to use the torch.compile cache to compile a model on one GPU and then reuse the compiled artifacts across the other GPUs attached to the same node.
For instance, I compile the model on GPU 0 (rank 0) of an 8xH100 machine and then want the other GPUs on the same node to reuse the cache. This is what my code looks like, and it doesn't seem to work:
```python
if rank == 0:
    model_full = Model(apply_regional_compilation=False, hidden_size=hidden_size, num_layers=num_layers).to(device)
    full_compiled_model = torch.compile(model_full)
    full_compile_time = compile_only(full_compiled_model, input_tensor, rank, "full")
    # Warmup (not included in compilation time)
    warmup(full_compiled_model, input_tensor, rank, num_iters=3)
    # Signal other ranks that compilation is complete
    print(f"[Rank {rank}] Full compilation complete, signaling other ranks...")
    artifacts = torch.compiler.save_cache_artifacts()
    assert artifacts is not None
    artifact_bytes, cache_info = artifacts
    print(f"[Rank {rank}] Cache info: {cache_info}")
    # Broadcast artifact_bytes from rank 0 to all other ranks
    artifact_bytes_list = [artifact_bytes]
    dist.broadcast_object_list(artifact_bytes_list, src=0)
else:
    artifact_bytes_list = [None]
    dist.broadcast_object_list(artifact_bytes_list, src=0)
    artifact_bytes = artifact_bytes_list[0]
    torch.compiler.load_cache_artifacts(artifact_bytes)
    model_full = Model(apply_regional_compilation=False, hidden_size=hidden_size, num_layers=num_layers).to(device)
    full_compiled_model = torch.compile(model_full)
    full_compile_time = compile_only(full_compiled_model, input_tensor, rank, "full (FROM CACHE)")
    warmup(full_compiled_model, input_tensor, rank, num_iters=3)
dist.barrier()
```
I launch the experiment with:

```shell
export TORCH_LOGS="+torch._inductor.codecache"; torchrun --standalone --nproc_per_node 2 multi_gpu_compile.py
```
Looking through the logs generated by TORCH_LOGS=+torch._inductor.codecache, I see that the FX graph has a different hash on each rank, and my guess is that this is why each rank still goes through the entire compilation process despite the loaded artifacts. Is this expected?
Is there a way around this, in both single-node and multi-node scenarios?
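To make my guess concrete: if some device-specific detail (e.g. the device index, cuda:0 vs. cuda:1) is folded into the cache key, then byte-identical graphs would still hash differently per rank and the transferred cache would never be hit. This is just a toy sketch of that failure mode with hashlib; it is not Inductor's actual key computation, and the inputs are made up for illustration:

```python
import hashlib

def toy_cache_key(graph_src: str, device: str) -> str:
    # Hypothetical cache key: hash the graph source together with the
    # device string, mimicking a key that bakes in the device index.
    return hashlib.sha256(f"{graph_src}|{device}".encode()).hexdigest()

# Same graph, different ranks/devices -> different keys -> cache miss.
key_rank0 = toy_cache_key("def forward(x): return x @ w", "cuda:0")
key_rank1 = toy_cache_key("def forward(x): return x @ w", "cuda:1")
print(key_rank0 != key_rank1)  # True: rank 1 cannot reuse rank 0's entry
```

If this is what is happening, the question becomes whether the device index can be excluded (or normalized) in the key for identical GPUs on the same node.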