FSDP unit combines the head and tail modules into a single flat parameter

As I understand it, FSDP splits the model into several FSDP units. However, there appears to be some inefficiency in the parameter-grouping strategy: it may combine the parameters of both the head and tail modules into a single flat parameter. To illustrate this, consider the following tiny demo.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=2000
)
model = Net().to(rank)
model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)

# FullyShardedDataParallel(
#   (_fsdp_wrapped_module): Net(
#     (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
#     (conv2): FullyShardedDataParallel(
#       (_fsdp_wrapped_module): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
#     )
#     (dropout1): Dropout(p=0.25, inplace=False)
#     (dropout2): Dropout(p=0.5, inplace=False)
#     (fc1): FullyShardedDataParallel(
#       (_fsdp_wrapped_module): Linear(in_features=9216, out_features=128, bias=True)
#     )
#     (fc2): Linear(in_features=128, out_features=10, bias=True)
#   )
# )

In this demonstration, the first (root) FSDP unit combines the conv1 and fc2 layers. When we perform the forward pass for conv1, we only need conv1's parameters, but the current setup also fetches fc2's parameters, even though they are not required yet.

Are there any configs to eliminate this unnecessary parameter fetch and improve efficiency?

cc @agu for fsdp autowrap


Thanks for asking! The FSDP API does have a limitation around how it can be applied to group parameters together for communication, namely that each FSDP unit must be specified on a single nn.Module.

Auto wrapping (i.e. auto_wrap_policy) is syntactic sugar to help users apply FSDP. The size-based policy can often lead to inefficient constructions like the one you have pointed out, so we do not recommend it.
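One alternative worth noting (a sketch of my own, assuming a PyTorch version where torch.distributed.fsdp.wrap exposes ModuleWrapPolicy) is a class-based policy, which wraps by module type instead of by parameter count, so each Conv2d and Linear becomes its own FSDP unit and is fetched only when needed:

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Wrap every Conv2d and Linear in its own FSDP unit, regardless of size,
# so conv1 and fc2 never end up sharing a flat parameter.
policy = ModuleWrapPolicy({nn.Conv2d, nn.Linear})

# Inside an initialized distributed run (process group required):
# model = FSDP(Net().to(rank), auto_wrap_policy=policy)
```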

In cases like this, if you want “optimal” grouping, then you may need to rewrite your nn.Module to construct dummy parent modules for the modules you want to group together and apply FSDP to those dummy parents.
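A sketch of that dummy-parent idea (my own illustration; ConvBlock, FcBlock, and GroupedNet are hypothetical names, not from this thread): nest the layers you want grouped under a shared parent module, then apply FSDP to only those parents:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Dummy parent: conv1 + conv2 become one FSDP unit."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

class FcBlock(nn.Module):
    """Dummy parent: fc1 + fc2 become another FSDP unit."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9216, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.dropout2(F.relu(self.fc1(x)))
        return self.fc2(x)

class GroupedNet(nn.Module):
    """Same computation as Net, restructured around the dummy parents."""
    def __init__(self):
        super().__init__()
        self.convs = ConvBlock()
        self.dropout1 = nn.Dropout(0.25)
        self.fcs = FcBlock()

    def forward(self, x):
        x = self.convs(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        return F.log_softmax(self.fcs(x), dim=1)

# Inside an initialized distributed run, wrap the dummy parents explicitly:
# from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# from torch.distributed.fsdp.wrap import ModuleWrapPolicy
# model = FSDP(GroupedNet().to(rank),
#              auto_wrap_policy=ModuleWrapPolicy({ConvBlock, FcBlock}))
```

This keeps conv1/conv2 in one flat parameter and fc1/fc2 in another, so the forward pass for the conv block no longer all-gathers the fc parameters.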

Hi @agu, thanks for your reply. Since auto wrapping is not recommended, can you suggest an alternative?

Additionally, regarding the dummy parent module pattern: does it mean adding dummy modules so that parameters are grouped by module boundaries rather than by the size-based rule? It would be helpful if you could provide a demonstration or unit test to clarify the concept.

Thanks a lot.