How to do inference using FSDP when the model doesn't fit on one GPU?

I’m trying to run simple inference on Llama 3 70B with 4 GPUs of 80 GB each. I use TRL, but TRL only wraps the model in an FSDP(…) wrapper. Should I run inference from the main process only, or from all processes? And why do my ranks time out after 10 minutes on an ALLREDUCE operation (FSDP shouldn't be issuing an all-reduce here, right?)? See the error below.

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Warning 1: setting some parameters explicitly can cause issues inside the PPOTrainer init
    # below when running under Accelerate, e.g. device_map="auto" combined with FSDP fails in
    # Accelerate's prepare_model call inside PPOTrainer.
    # Warning 2: do not wrap the model with Accelerator here; PPOTrainer already does that.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        context.model_name,
        peft_config=lora_config,
        use_cache=False,
        device_map=None,
    )

    logger.info("Loaded model.")
    model.eval()
    with accelerator.main_process_first():
        logger.info("Create a dataset for training")
        steg_ds = build_dataset(context, tokenizer=tokenizer)
    accelerator.wait_for_everyone()

    ppo_trainer = PPOTrainer(
        context.ppo_config,
        model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=steg_ds["train"],
        data_collator=collator,
    )
    logger.info("PPOTrainer created.")
    model = ppo_trainer.model
    accelerator.wait_for_everyone()

    # Try inference:
    if accelerator.is_main_process:  # Should inference run only on this rank, or on all ranks?
        logger.info("Inference test:")
        logger.info(ppo_trainer.model)
        with torch.no_grad():
            tensors = tokenizer(
                "Be concise. The capital of France is", return_tensors="pt"
            )["input_ids"].to(device)

            @torch.no_grad()
            def greedy_search(model, input_ids, max_new_tokens):
                res = []
                for _ in range(max_new_tokens):
                    # Avoid model.generate here: under FSDP it does not gather the weights for
                    # the LM head. See https://github.com/pytorch/pytorch/issues/100069
                    logger.debug(f"before forward {input_ids.shape=}")
                    # AutoModelForCausalLMWithValueHead.forward returns (lm_logits, loss, value).
                    logits = model(input_ids)[0]
                    new_token = logits.argmax(dim=-1)[:, -1:]
                    res.append(new_token)
                    input_ids = torch.cat([input_ids, new_token], dim=-1)
                    logger.debug(f"{new_token=} {tokenizer.decode(new_token[0])}")
                return torch.cat(res, dim=-1)

            res = greedy_search(ppo_trainer.model, tensors, 3)
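
Is the right approach instead something like the sketch below, where every rank participates in the forward pass (so FSDP can all-gather its parameter shards) and only the main process decodes and logs the result? This is untested; the helper name run_all_ranks_greedy and the prompt are just placeholders.

    # Untested sketch: every rank runs the forward pass; only the main process logs.
    @torch.no_grad()
    def run_all_ranks_greedy(model, tokenizer, accelerator, prompt, max_new_tokens=3):
        input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(accelerator.device)
        for _ in range(max_new_tokens):
            # All ranks must call forward so FSDP can gather the sharded parameters.
            logits = model(input_ids)[0]  # value-head model returns (lm_logits, loss, value)
            new_token = logits.argmax(dim=-1)[:, -1:]
            input_ids = torch.cat([input_ids, new_token], dim=-1)
        if accelerator.is_main_process:
            logger.info(tokenizer.decode(input_ids[0]))
        return input_ids

    # Called from every process, not inside `if accelerator.is_main_process:`:
    # run_all_ranks_greedy(ppo_trainer.model, tokenizer, accelerator, "The capital of France is")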

The bug / timeout error:

    ...
    [rank2]:[E611 11:51:04.293780141 ProcessGroupNCCL.cpp:572] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=487, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600012 milliseconds before timing out.
    [rank2]:[E611 11:51:04.295536178 ProcessGroupNCCL.cpp:1587] [PG 0 (default_pg) Rank 2] Exception (either an error or timeout) detected by watchdog at work: 487, last enqueued NCCL work: 487, last completed NCCL work: 486.
    [rank1]:[E611 11:51:04.302261338 ProcessGroupNCCL.cpp:572] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=487, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600020 milliseconds before timing out.
    [rank1]:[E611 11:51:04.302629418 ProcessGroupNCCL.cpp:1587] [PG 0 (default_pg) Rank 1] Exception (either an error or timeout) detected by watchdog at work: 487, last enqueued NCCL work: 487, last completed NCCL work: 486.
    [rank3]:[E611 11:51:04.344478076 ProcessGroupNCCL.cpp:572] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=487, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600062 milliseconds before timing out.
    [rank3]:[E611 11:51:04.344840296 ProcessGroupNCCL.cpp:1587] [PG 0 (default_pg) Rank 3] Exception (either an error or timeout) detected by watchdog at work: 487, last enqueued NCCL work: 487, last completed NCCL work: 486.
    compute-permanent-node-990:236514:237045 [2] NCCL INFO [Service thread] Connection closed by localRank 2
    compute-permanent-node-990:236515:237047 [3] NCCL INFO [Service thread] Connection closed by localRank 3
    ...
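
The Timeout(ms)=600000 in the log looks like the default 10-minute process-group timeout. While debugging, I assume it can be raised through Accelerate's InitProcessGroupKwargs, something like this sketch (not what my script currently does):

    from datetime import timedelta

    from accelerate import Accelerator, InitProcessGroupKwargs

    # Sketch: raise the NCCL collective timeout from 10 minutes to 2 hours while debugging the hang.
    accelerator = Accelerator(
        kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(hours=2))]
    )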

My accelerate config:

    compute_environment: LOCAL_MACHINE
    debug: true
    distributed_type: FSDP
    downcast_bf16: 'no'
    enable_cpu_affinity: false
    fsdp_config:
      #fsdp_auto_wrap_policy: SIZE_BASED_WRAP
      fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
      fsdp_backward_prefetch: BACKWARD_PRE
      fsdp_cpu_ram_efficient_loading: true
      fsdp_forward_prefetch: false
      fsdp_offload_params: true
      fsdp_sharding_strategy: FULL_SHARD
      fsdp_state_dict_type: SHARDED_STATE_DICT
      fsdp_sync_module_states: true
      fsdp_use_orig_params: true
      fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
    machine_rank: 0
    main_training_function: main
    mixed_precision: bf16
    num_machines: 1
    num_processes: 4
    rdzv_backend: static
    same_network: true
    tpu_env: []
    tpu_use_cluster: false
    tpu_use_sudo: false
    use_cpu: false
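
For completeness, I launch the script with something like the following (the config file and script names are placeholders):

    accelerate launch --config_file fsdp_config.yaml run_ppo_inference.py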