Tips for wrapping Conv layers in FSDP

Hey, thanks for putting together the transformer_auto_wrap_policy for FSDP. I wanted to check whether there are any tips on which layers we can combine when wrapping Conv blocks, or whether wrapping whole blocks in a single FSDP unit is good enough.

Any insight would be great!

Is it possible to provide a printout of your model definition? I may be able to give more directed help in that case.

I was thinking of something like the CoAtNet models, for instance the timm implementation [1].

They are kind of tedious to print out, so here is the whole thing in a gist: "timm coatnet definition as per https://arxiv.org/pdf/2106.04803v2.pdf" (GitHub).

The basic structure is a few MbConv blocks followed sequentially by Transformer blocks. I wrapped the transformer layers with transformer_auto_wrap_policy as documented, and tried wrapping the Conv blocks with size_based_auto_wrap_policy, but that felt inefficient.
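Roughly, the size-based attempt for the Conv blocks looked like this (a sketch; the min_num_params threshold is illustrative, not the exact value I used):

import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any subtree whose unwrapped parameter count exceeds the threshold,
# then pass this as auto_wrap_policy to the FSDP constructor.
# The threshold below is illustrative, not the exact value I used.
conv_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,
)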

Looking forward to your suggestions!

[1] - pytorch-image-models/maxxvit.py at master · rwightman/pytorch-image-models · GitHub

Is the gist representative of your scaled model? If not and you are planning to scale the model further, would it be by increasing individual nn.Parameter sizes, adding more layers, or both?

I was thinking that something like the following might be good for scaling:

import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Module classes from the timm MaxxVit / CoAtNet implementation; the exact
# import paths may differ slightly with your timm version (e.g.
# timm.models.layers in older releases).
from timm.models.maxxvit import Stem, MaxxVitStage, TransformerBlock2d
from timm.layers import ClassifierHead

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        Stem,
        MaxxVitStage,
        TransformerBlock2d,  # possibly
        ClassifierHead,
    },
)

If the parameter sizes are exactly as in that gist, then including TransformerBlock2d may not be worth it. Let me know if you encounter any issues!
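If it helps, here is a minimal sketch of passing that policy to the FSDP constructor (assuming the process group is already initialized; the timm model name below is just an example of a CoAtNet variant, not the exact one in your gist):

import timm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has already been called
# and that auto_wrap_policy is the partial defined above. The model name is
# only an example of a timm CoAtNet variant.
model = timm.create_model("coatnet_0_rw_224", pretrained=False).cuda()
sharded_model = FSDP(model, auto_wrap_policy=auto_wrap_policy)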

This is interesting, so it should be safe to add non-transformer classes to transformer_auto_wrap_policy? I wasn't sure whether adding Conv layers to it would be safe, so I was thinking of writing a wrapper around transformer_auto_wrap_policy and size_based_auto_wrap_policy. This is insightful, thanks!
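For context, the wrapper I had in mind would just have OR-ed the two policies inside a single callable, roughly like this sketch (the layer set and threshold are illustrative; recent PyTorch calls the policy with nonwrapped_numel, while older releases used unwrapped_params):

import functools

from timm.models.maxxvit import TransformerBlock2d  # import path per the timm maxxvit.py linked above
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)

transformer_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock2d},
)
size_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,  # illustrative threshold
)

# Wrap a module if either policy says so. FSDP calls the policy with keyword
# arguments; the count argument is nonwrapped_numel in recent PyTorch
# (older releases used unwrapped_params).
def combined_auto_wrap_policy(module, recurse, nonwrapped_numel):
    return transformer_policy(
        module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel
    ) or size_policy(
        module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel
    )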

The name is a misnomer, and we need to address that: the policy simply wraps any module whose class appears in transformer_layer_cls, so there is nothing transformer-specific about it. Apologies for the confusion.