Hey, thanks for putting together transformer_auto_wrap_policy for FSDP. I wanted to check whether there are any tips on which layers we can combine when wrapping Conv blocks, or whether wrapping the whole blocks in an FSDP unit should be good enough.
Any insight would be great!
Is it possible to provide a printout of your model definition? I may be able to provide more directed help in that case.
I was thinking of something like the CoAtNet models, the timm implementation for instance.
It is fairly tedious to print out, so here's the whole thing: timm coatnet definition as per https://arxiv.org/pdf/2106.04803v2.pdf · GitHub
The basic structure is a few MbConv blocks, sequentially followed by Transformer blocks. I wrapped the transformer layers with transformer_auto_wrap_policy as described, and tried wrapping the Conv blocks with size_based_auto_wrap_policy, but that felt inefficient.
Looking forward to your suggestions!
(Reference: pytorch-image-models/maxxvit.py at master · rwightman/pytorch-image-models · GitHub)
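For context, the size-based approach mentioned above can be sketched as follows. This assumes `size_based_auto_wrap_policy` from `torch.distributed.fsdp.wrap`; the 1e6 threshold and the `nn.Linear` modules are purely illustrative:

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any module whose non-wrapped parameter count crosses the threshold.
# The 1_000_000 value here is illustrative, not a recommendation.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)

# The policy is just a callable that FSDP queries per module with
# (module, recurse, nonwrapped_numel); with recurse=False it returns the
# wrap decision based purely on parameter count, not module type.
print(policy(nn.Linear(4, 4), False, 2_000_000))  # above threshold -> wrap (True)
print(policy(nn.Linear(4, 4), False, 10))         # below threshold -> skip (False)
```

This is why it can feel inefficient for heterogeneous Conv stages: the decision ignores module boundaries entirely and only counts parameters.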
Is the gist representative of your scaled model? If not and you are planning to scale the model further, would it be by increasing individual
nn.Parameter sizes, adding more layers, or both?
I was thinking that something like the following might be good for scaling:

```python
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MbConvBlock, TransformerBlock2d},  # TransformerBlock2d possibly
)
```
If the parameter sizes are exactly as in that gist, then including
TransformerBlock2d may not be worth it. Let me know if you encounter any issues!
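As a sanity check, a policy built this way can be exercised directly before handing it to FSDP. The block classes below are illustrative stand-ins, not the actual timm classes:

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Stand-ins for the timm blocks discussed above (names are illustrative).
class MbConvBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)

class TransformerBlock2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MbConvBlock, TransformerBlock2d},
)

# With recurse=False the policy simply checks class membership in
# transformer_layer_cls, so Conv-based blocks work the same as attention ones.
print(policy(MbConvBlock(), False, 0))    # class in the set -> wrap (True)
print(policy(nn.Linear(4, 4), False, 0))  # class not in the set -> skip (False)
```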
This is interesting. So it should be safe to add non-transformer classes to transformer_auto_wrap_policy? I wasn't sure whether adding Conv layers to it would be safe, so I was thinking of adding a wrapper class around size_based_auto_wrap_policy. This is insightful, thanks!
The name is a misnomer, and we need to address that. Apologies for the confusion.
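For what it's worth, more recent PyTorch releases expose the same class-membership behavior under a less confusing name, ModuleWrapPolicy. A minimal sketch, assuming PyTorch 2.0 or later (the classes in the set are illustrative):

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Same behavior as transformer_auto_wrap_policy with transformer_layer_cls,
# but the name no longer implies the classes must be transformer layers.
policy = ModuleWrapPolicy({nn.Conv2d, nn.MultiheadAttention})

# Passed to FSDP the same way, e.g. FSDP(model, auto_wrap_policy=policy)
```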