Maximize FSDP throughput for Mistral on a multi-node cluster

Is there anything obvious I should be doing? I tried compiling the model but was getting Dynamo errors. I’m not sure if there are any other best practices I’m missing.

I’m currently seeing a throughput of around 90 samples per second at a max context length of 2600 tokens (though the average sample is only around 500 tokens) on 80 GPUs in my cluster. On a single node with 8 GPUs I get around 11.2 samples per second, and there the best setup is SHARD_GRAD_OP (ZeRO stage 2) with no gradient checkpointing. So the multi-node run sits at roughly 80% of linear scaling relative to the single-node number (11.2 × 10 ≈ 112 samples per second).

The main bottleneck is the networking, so running the largest possible batch size maximizes throughput: the communication cost per step stays roughly the same regardless of the batch size. For that reason I ended up using _HYBRID_SHARD_ZERO2 and enabling gradient checkpointing, which lets me fit a batch size of 20 samples per GPU at 2600 max length.
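For reference, a quick sketch of how the FSDP sharding strategies roughly map to ZeRO stages (assuming a recent PyTorch; the exact set of options depends on the version):

from torch.distributed.fsdp import ShardingStrategy

# Rough mapping of FSDP sharding strategies to ZeRO stages:
ShardingStrategy.FULL_SHARD            # ZeRO-3: shard params, grads, and optimizer state
ShardingStrategy.SHARD_GRAD_OP         # ZeRO-2: shard grads and optimizer state only
ShardingStrategy.NO_SHARD              # plain data parallel (DDP-like)
ShardingStrategy.HYBRID_SHARD          # FULL_SHARD within a node, replicate across nodes
ShardingStrategy._HYBRID_SHARD_ZERO2   # SHARD_GRAD_OP within a node, replicate across nodes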

These are the main parts to look at:

Model setup

I’m currently using _HYBRID_SHARD_ZERO2, but I have experimented with all the sharding strategies. I couldn’t get torch.compile to work, and I had to enable gradient checkpointing to maximize the batch size.

import math
from functools import partial

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer


def setup_model(model_name, tokenizer):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    if len(tokenizer) > model.config.vocab_size:
        print(
            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
        )
        model.resize_token_embeddings(
            int(8 * math.ceil(len(tokenizer) / 8.0))
        )  # make the vocab size multiple of 8 for sharding the embedding layer.

    assert model.__class__.__name__ in [
        "MistralForCausalLM"
    ], f"Model class name: {model.__class__.__name__} is not supported."

    model = FSDP(
        model,
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        ),
        # use_orig_params=True,
        limit_all_gathers=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
        device_id=torch.cuda.current_device(),
    )
    model.gradient_checkpointing_enable()
    # model = torch.compile(model)
    return model
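On the torch.compile side, one thing I still want to try (a sketch under the assumption that the Dynamo errors come from FSDP’s flat-parameter handling; not verified on this setup): FSDP generally needs use_orig_params=True for torch.compile to trace through it, and that flag is commented out above.

# Hedged sketch, not benchmarked here: enable use_orig_params so Dynamo can
# trace the wrapped model, then compile after wrapping and checkpointing.
model = FSDP(
    model,
    auto_wrap_policy=partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            MistralDecoderLayer,
        },
    ),
    use_orig_params=True,  # the key change for torch.compile compatibility
    limit_all_gathers=True,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
    device_id=torch.cuda.current_device(),
)
model.gradient_checkpointing_enable()
model = torch.compile(model)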

Training loop

Importantly, use_cache needs to be False during training. Even though the use_cache=False argument is commented out below (so it defaults to True from the model config), it still ends up as False, but only because gradient checkpointing is enabled: Transformers forces use_cache=False whenever checkpointing is on.
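To avoid relying on that side effect, one option (just a suggestion, not something the setup above requires) is to turn the cache off explicitly on the config:

# Optional: disable the KV cache explicitly instead of relying on
# gradient checkpointing to turn it off at forward time.
model.config.use_cache = False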

        for batch in train_loader:
            start = time.time()
            for k in batch:
                batch[k] = batch[k].to(local_rank)

            output = model(
                **batch,
                # use_cache=False,
            )

            loss = output["loss"]
            loss.backward()

            global_step += 1  # assumes global_step starts at 0 and is incremented only here
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
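Since the network is the main bottleneck, one more thing worth trying is skipping the gradient reduction on the accumulation steps with FSDP’s no_sync(). This is a hedged sketch reusing the same names as the loop above (model, train_loader, optimizer, scheduler, args, local_rank, global_step); it also scales the loss by the number of accumulation steps, which the loop above doesn’t do, and it accumulates unsharded gradients locally between optimizer steps, so it costs extra memory:

from contextlib import nullcontext

for batch in train_loader:
    batch = {k: v.to(local_rank) for k, v in batch.items()}
    global_step += 1

    is_update_step = global_step % args.gradient_accumulation_steps == 0
    # Skip the inter-GPU gradient reduction on non-update steps; gradients
    # accumulate locally in unsharded form instead (watch memory usage).
    sync_ctx = nullcontext() if is_update_step else model.no_sync()

    with sync_ctx:
        output = model(**batch, use_cache=False)
        # scale so the accumulated gradient matches one large batch
        loss = output["loss"] / args.gradient_accumulation_steps
        loss.backward()

    if is_update_step:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()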