FSDP issue with invertible networks

I am new to FSDP and trying to implement it for my model training (replacing my DDP usage). My model is a normalizing flow (implemented with flowtorch), which is an invertible network. The train step looks like this:

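from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def train_step(model: FSDP, batch_size: int):
    # Calling the FSDP wrapper runs Model.forward, which returns the loss.
    loss = model(batch_size)
    return loss

class Model(nn.Module):
    # sample() and log_prob() are provided by the flow and are not shown here.

    def forward(self, batch_size):
        return calc_loss(self, batch_size)

def calc_loss(model, batch_size):
    # Run the flow forwards to draw samples (calls bijector.forward).
    samples = model.sample([batch_size])
    # Run the flow in reverse to score those samples (calls bijector.inverse).
    log_prob = model.log_prob(samples)
    # Negative log-likelihood of the samples.
    return -log_prob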

The normalizing flow is a composition of (trainable) bijectors. The sample method runs the flow forwards to produce random samples; it calls bijector.forward on each bijector. The log_prob method runs the flow in reverse to compute the probability densities of those samples; it calls bijector.inverse on each one.
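To make the two directions concrete, here is a minimal pure-Python sketch of a flow built from two scalar affine bijectors. It is not flowtorch code: the `Affine` and `Flow` classes, the standard normal base distribution, and the `log_abs_det_jacobian` helper are all assumptions chosen for illustration.

```python
import math
import random

class Affine:
    """Toy bijector y = scale * x + shift (hypothetical, scale != 0)."""

    def __init__(self, scale, shift):
        self.scale, self.shift = scale, shift

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        return (y - self.shift) / self.scale

    def log_abs_det_jacobian(self):
        # For a scalar affine map, dy/dx is just the scale.
        return math.log(abs(self.scale))

class Flow:
    """Composition of bijectors over a standard normal base distribution."""

    def __init__(self, bijectors):
        self.bijectors = bijectors

    def sample(self, n, rng):
        # Forward direction: push base samples through each bijector in order.
        xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
        for b in self.bijectors:
            xs = [b.forward(x) for x in xs]
        return xs

    def log_prob(self, y):
        # Reverse direction: invert the bijectors last to first,
        # accumulating the log-determinant (change of variables).
        log_det = 0.0
        for b in reversed(self.bijectors):
            y = b.inverse(y)
            log_det += b.log_abs_det_jacobian()
        base_log_prob = -0.5 * y * y - 0.5 * math.log(2 * math.pi)
        return base_log_prob - log_det

flow = Flow([Affine(2.0, 1.0), Affine(0.5, -3.0)])
samples = flow.sample(4, random.Random(0))
loss = -sum(flow.log_prob(s) for s in samples) / len(samples)
```

The point is that `log_prob` touches the same parameters as `sample`, but through `inverse` rather than `forward`, and in the opposite order.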

The forward direction (sample) works. But when log_prob calls bijector.inverse, I get errors suggesting that the shards were not gathered (“mat2 must be a matrix, got 1-D tensor”). Even when I run only the log_prob method, outside the train_step (by hijacking the model's forward method to call log_prob instead of calc_loss), I still get errors because the shards were not gathered.

Is there a solution to this? I would like to shard my parameters if possible, but I don’t know how to get FSDP to gather the shards again when I start running the units in reverse order to compute the log_prob. Right now I can only get it to run with ShardingStrategy.SHARD_GRAD_OP.
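One thing that may be worth checking, offered as a possibility rather than a diagnosis: FSDP gathers a wrapped module's parameters around calls to that module's forward, so a method such as inverse that is called directly on a wrapped submodule may bypass that step. The sketch below routes both directions through forward using a flag. `AffineBijector` and its `inverse` argument are hypothetical, not flowtorch's API, and a plain forward pre-hook stands in for FSDP's pre-forward logic so the example runs on a single process.

```python
import torch
from torch import nn

class AffineBijector(nn.Module):
    """Hypothetical bijector whose two directions both go through forward()."""

    def __init__(self, dim):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x, inverse=False):
        # Dispatch on a flag so both directions enter via __call__,
        # which is where module hooks (and wrappers) get a chance to run.
        return self._inverse(x) if inverse else self._forward(x)

    def _forward(self, x):
        return x * self.log_scale.exp() + self.shift

    def _inverse(self, y):
        return (y - self.shift) * (-self.log_scale).exp()

calls = []
bij = AffineBijector(3)
# Stand-in for FSDP's pre-forward step: record every entry through __call__.
bij.register_forward_pre_hook(lambda module, args: calls.append("pre"))

x = torch.randn(5, 3)
y = bij(x)                     # forward direction
x_back = bij(y, inverse=True)  # inverse direction, still via __call__
```

Calling `bij._inverse(y)` directly would skip the hook, which is the situation this pattern is meant to avoid.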

Thanks in advance!

OK, my problem more or less fixed itself. While experimenting with a stripped-down version of my model, I found that setting min_num_params to 1e6 instead of 1e7 made the problem go away. I also changed my code so that every parameter has requires_grad=True, and turned everything else into a buffer instead of a parameter. I did this so that I could set use_orig_params=False. Now everything works, even if I go back to 1e7 or pass use_orig_params=True.
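For anyone landing here later, the changes described above look roughly like the sketch below. The `Block` module and the `wrap` helper are made up for illustration and are not my actual model; `wrap` also needs an initialized process group, so it is only defined here, not called.

```python
import functools

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

class Block(nn.Module):
    """Hypothetical module: trainable tensors are parameters, fixed ones are buffers."""

    def __init__(self, dim):
        super().__init__()
        # Trainable, so requires_grad stays True.
        self.weight = nn.Parameter(torch.randn(dim, dim))
        # Fixed tensor: registered as a buffer rather than as
        # nn.Parameter(..., requires_grad=False).
        self.register_buffer("permutation", torch.randperm(dim))

def wrap(model, min_num_params=int(1e6)):
    # Requires torch.distributed to be initialized before it is called.
    policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_num_params
    )
    return FSDP(
        model,
        auto_wrap_policy=policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=False,
    )

block = Block(4)
```

With every parameter trainable, there is no mix of requires_grad values for FSDP to flatten together, which is what I was aiming for when setting use_orig_params=False.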

Apparently this had nothing to do with invertible networks. Maybe it was an issue with parameters that didn’t require grads. Who knows!
