I’m trying to play with FSDP, but it seems that after wrapping, the model’s custom methods can no longer see its parameters. This is a minimal reproducible snippet:
from torch import Tensor, nn
import torch.distributed as dist
import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist_rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
dist.init_process_group("nccl", rank=dist_rank, world_size=world_size)
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3)

    def foo(self):
        print(list(self.parameters()))
        return 1

model = MyModel()
model2 = FSDP(model)
model2.foo()
Here, even though model2.parameters() still returns the parameters, inside foo only an empty list is printed.
Hmm, OK. Then I think my question would be how to access the parameters. I found that before wrapping with FSDP, I could call model.parameters() no problem. After wrapping, model.parameters() is an empty list, but the FSDP-wrapped model2.parameters() still has the FlatParameter content. Is this also expected? This makes accessing MyModel’s parameters inside foo return an empty list.
Or should I be accessing parameters differently?
(I edited the snippet in the original question a bit to reflect how I accessed the parameters inside foo)
Our models have some custom functions that, for example, get the optimizer. But if this is the desired behavior (i.e., we’re not supposed to access parameters from within the model), then I can work around it. Still, I find it a little surprising that after wrapping in FSDP, model.parameters() turns up empty. Why is that the case?
Sorry, I did not quite follow. Could you explain in more detail how you are using .parameters() to get the optimizer?
After v1.13, we removed FlattenParamsWrapper. This means that the FlatParameters are registered to the wrapped module (e.g., the MyModel instance in your example). In that case, if you call model.foo() after model2 = FSDP(model), you should see the FlatParameters.
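To make the registration difference concrete, here is a toy sketch of the two schemes (this is not FSDP itself; flatten_onto is a made-up helper that mimics what parameter flattening does to registration). When the flat parameter lives on a separate wrapper module, the inner module's .parameters() is empty; when it is registered on the wrapped module itself, the inner module still sees it:

```python
import torch
from torch import nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3)

    def foo(self):
        return list(self.parameters())

def flatten_onto(module, target):
    # Toy stand-in for flattening: concatenate all of `module`'s
    # parameters into one flat tensor, deregister the originals
    # (this breaks forward(); it is only a registration sketch),
    # and register the flat parameter on `target`.
    params = list(module.named_parameters())
    flat = nn.Parameter(torch.cat([p.detach().reshape(-1) for _, p in params]))
    for name, _ in params:
        submod_name, _, attr = name.rpartition(".")
        delattr(module.get_submodule(submod_name), attr)
    target.flat_param = flat  # assigning an nn.Parameter registers it

# Pre-1.13 style: the flat parameter lives on a wrapper module.
m1 = Inner()
wrapper = nn.Module()
wrapper.inner = m1
flatten_onto(m1, wrapper)
print(len(m1.foo()))                    # 0 -- the inner module sees nothing
print(len(list(wrapper.parameters())))  # 1 -- only the wrapper has it

# Post-1.13 style: the flat parameter is registered on the wrapped module.
m2 = Inner()
flatten_onto(m2, m2)
print(len(m2.foo()))                    # 1 -- foo() now sees the flat parameter
```

In the second case, a custom method like foo() that iterates self.parameters() sees the FlatParameter, which is why the behavior differs across versions.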
However, I ask for more clarification on item 1 because I want to see if there is a more recommended way to achieve what you want. Relying on the fact that the FlatParameters are registered to the wrapped module might not be robust.
I see. In that case, it should work after the change that landed after 1.13 that I mentioned. However, is there any reason why the optimizer cannot be created externally? Alternatively, could you make your method a static method and pass in the model, which can later be FSDP-wrapped?
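The static-method approach might look like the following sketch (configure_optimizer and the SGD hyperparameters are made up for illustration; they are not part of your code or the FSDP API):

```python
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3)

    @staticmethod
    def configure_optimizer(module, lr=0.1):
        # Takes the module explicitly instead of using `self`, so the
        # caller can pass in the FSDP-wrapped model and the optimizer
        # is built over the parameters the wrapper exposes.
        return torch.optim.SGD(module.parameters(), lr=lr)

model = MyModel()
# In a distributed setup you would wrap first and pass the wrapper:
#   wrapped = FSDP(model)
#   optimizer = MyModel.configure_optimizer(wrapped)
optimizer = MyModel.configure_optimizer(model)
```

This keeps the optimizer-construction logic on the model class while avoiding any dependence on where the FlatParameters happen to be registered.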
Since the PR with the change landed after the 1.13 branch cut, it will only appear in the next release. You could access it from a nightly build if that works for your case.