I am new to FSDP and trying to implement it for my model training (replacing my DDP usage). My model is a normalizing flow (implemented with flowtorch), which is an invertible network. The train step looks like this:
```python
def train_step(model: FSDP, batch_size: int):
    ...
    loss = model(batch_size)
    ...


class Model(nn.Module):
    ...
    def forward(self, batch_size):
        return calc_loss(self, batch_size)


def calc_loss(model, batch_size):
    samples = model.sample([batch_size])   # runs each bijector.forward
    log_prob = model.log_prob(samples)     # runs each bijector.inverse
    return -log_prob
```
The normalizing flow is a composition of (trainable) bijectors. The sample method runs the flow forwards to produce random samples; it calls bijector.forward on each bijector. The log_prob method runs the flow in reverse to compute the probability densities of those samples; it calls bijector.inverse on each one.
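Schematically, the structure is something like the following. This is not the actual flowtorch API (in particular, the inverse signature returning a log-det term is made up here); it is just a minimal stand-in to show that log_prob walks the same bijectors in the opposite direction:

```python
import torch.nn as nn

class ComposedFlow(nn.Module):
    """Simplified stand-in for the flowtorch flow: a stack of trainable bijectors."""
    def __init__(self, bijectors, base_dist):
        super().__init__()
        self.bijectors = nn.ModuleList(bijectors)
        self.base_dist = base_dist

    def sample(self, shape):
        x = self.base_dist.sample(shape)
        for b in self.bijectors:            # forward direction
            x = b.forward(x)
        return x

    def log_prob(self, x):
        log_det = 0.0
        for b in reversed(self.bijectors):  # reverse direction: bijector.inverse
            x, ldj = b.inverse(x)           # hypothetical signature: (inverse-mapped x, log|det J|)
            log_det = log_det + ldj
        return self.base_dist.log_prob(x) + log_det
```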
The forward direction (sample) is fine. But when the flow runs bijector.inverse (for log_prob), I get errors suggesting that the parameter shards were not gathered (“mat2 must be a matrix, got 1-D tensor”). Even when I run log_prob on its own, outside the train step (by hijacking the model's forward method to call log_prob instead of calc_loss), I still get the same ungathered-shard errors.
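That isolation test is roughly this (simplified):

```python
class Model(nn.Module):
    ...
    def forward(self, samples):
        # Temporary hijack: compute only log_prob, so the reverse pass is the
        # entire FSDP forward. This still fails with the same error.
        return self.log_prob(samples)
```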
Is there a solution to this? I would like to shard my parameters if possible, but I don’t know how to get FSDP to gather the shards again when I start running the units in reverse order to compute the log_prob. Right now I can only get it to run with ShardingStrategy.SHARD_GRAD_OP.
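For reference, the wrapping that currently works looks roughly like this (simplified; the commented-out FULL_SHARD / auto-wrap line is what I would like to use instead):

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
model = Model().cuda()

# Works: SHARD_GRAD_OP keeps parameters unsharded from the forward pass through
# the backward pass (re-sharding only between steps), which I assume is why
# bijector.inverse still sees full weights.
model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)

# What I would like instead, but which fails in bijector.inverse:
# model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD, auto_wrap_policy=...)
```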
Thanks in advance!