RuntimeError: setStorage Using FSDP

Hi, I am using FSDP, but I get the following error:

RuntimeError: setStorage: sizes [1, 1, 16], strides [16, 16, 1], storage offset 160, and itemsize 4 requiring a storage size of 704 are out of bounds for storage of size 0

I get the same error with both pytorch==2.3.1 and pytorch==1.12.1.

Could you help me debug this problem? Here is a code snippet that reproduces the error:

import os
import sys
import argparse
import torch
import functools

import torch.optim as optim
import torch.nn as nn

from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local-rank', type=int, default=-1)
    parser.add_argument('--master-port', type=int, default=-1)
    return parser.parse_args()


class Learner(nn.Module):
    def __init__(self, depth=3):
        super().__init__()
        vec_lst = []
        for _ in range(depth):
            vec_lst.append(
                nn.ParameterList([
                    nn.Parameter(torch.rand(c+1, 16)) for c in range(4)
                ])
            )
        self.vec_lst = nn.ParameterList(vec_lst)

    def forward(self):
        return self.vec_lst


class FreezeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.depth = 1

    def forward(self, x, vec_lst):
        bs = x.size(0)
        vec = vec_lst[self.depth]
        for i in range(4):
            x = x + vec[i].unsqueeze(0).expand(bs, -1, -1)
        return x


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.learner = Learner()
        self.freezed_model = FreezeModel()

    def forward(self, x):
        vec_lst = self.learner()
        x = self.freezed_model(x, vec_lst)
        x = x[:, 0, :]
        return x


def trainer_policy_fn(module):
    return isinstance(module, Learner)


def main():
    args = parse_arguments()
    torch.cuda.set_device(args.local_rank)
    init_process_group(
        init_method='env://',
        backend='nccl',
    )

    model = MyModel()
    for param in model.freezed_model.parameters():
        param.requires_grad_ = False

    cur_device = torch.cuda.current_device()
    my_auto_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=trainer_policy_fn)
    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        device_id=cur_device,
        use_orig_params=False,
        limit_all_gathers=True,
    )

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    input_tensor = torch.randn(4, 1, 16).cuda()
    target_tensor = torch.randn(4, 16).cuda(non_blocking=True)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        output = model(input_tensor)
        loss = criterion(output, target_tensor)
        optimizer.zero_grad()  # clear gradients left over from the previous iteration
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()

What happens if you don't set non_blocking to True? Also, are you running it on one GPU or multiple GPUs?

When I do not set non_blocking to True, I still get the same error.

I am running it on multiple GPUs. The command I use is CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --master_port=1234 --nproc_per_node=2 test.py.

Parameters managed by one FullyShardedDataParallel module are only all-gathered (and hence usable) during its forward().
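A simplified, single-process imitation of that lifetime rule, assuming PyTorch 2.x (for `untyped_storage()`); `flat` and `vec` are hypothetical names, and real FSDP does considerably more bookkeeping. The parameters are views into one unsharded flat buffer, and once that buffer's storage is freed after forward(), the views keep their shape and offset but no longer have any data behind them.

```python
import torch

# Hypothetical stand-in for an unsharded flat parameter holding Learner's
# 12 tensors: 3 depths x (16 + 32 + 48 + 64) = 480 float32 elements.
flat = torch.rand(3 * 160)

# The original parameters act as views into that buffer. vec_lst[1][0]
# starts right after depth 0's 160 elements.
vec = flat[160:176].view(1, 16)

# After the wrapped unit's forward() returns, the unsharded storage is freed,
# roughly like this:
flat.untyped_storage().resize_(0)

print(vec.shape)                     # torch.Size([1, 16]): metadata survives
print(vec.untyped_storage().size())  # 0: the data behind it is gone
# Building further views of `vec` (like vec.unsqueeze(0) in FreezeModel.forward())
# now happens against a storage of size 0, the condition the error reports.
```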

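However, in your model, FreezeModel.forward() uses the Learner parameters outside of Learner.forward(): Learner.forward() only returns self.vec_lst, and by the time FreezeModel.forward() indexes into it, FSDP has already resharded the Learner parameters and freed their unsharded storage. This is why you are seeing the setStorage error ("storage of size 0" just means that the storage has already been freed).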
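The numbers in the error message fit this explanation. Assuming FSDP flattened Learner's twelve parameters into one buffer in registration order with no padding (an assumption about the internal layout, not something stated in the thread), the failing view is vec_lst[1][0].unsqueeze(0): it starts at element 160 and needs 704 bytes of storage, while the freed storage has size 0.

```python
# Assumed layout (use_orig_params=False, no padding): Learner's parameters are
# concatenated in registration order vec_lst[0][0..3], vec_lst[1][0..3], ...
DIM = 16
ITEMSIZE = 4  # bytes per float32 element


def numel(c):
    # nn.Parameter(torch.rand(c + 1, 16)) holds (c + 1) * 16 elements
    return (c + 1) * DIM


per_depth = sum(numel(c) for c in range(4))  # 16 + 32 + 48 + 64
offset = 1 * per_depth  # FreezeModel.depth == 1, and vec[0] is read first

# vec[0].unsqueeze(0) has sizes [1, 1, 16] and strides [16, 16, 1]
sizes, strides = [1, 1, DIM], [DIM, DIM, 1]
last_element = offset + sum((s - 1) * st for s, st in zip(sizes, strides))
required_bytes = (last_element + 1) * ITEMSIZE

print(offset, required_bytes)  # 160 704
```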

How you should wrap your model depends on your actual model. If you have a more representative version, I can try to help a bit. Otherwise, if you just want your script to run, you can comment out the auto_wrap_policy argument. That way, all parameters belong to the root FSDP module and are all-gathered for the whole of MyModel.forward(). (Note, though, that this is not what you actually want for performance, since the entire model is then gathered at once.)
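Another way to respect the rule, if the computation allows it, is to make sure only activations (never the parameters themselves) leave the wrapped module. A minimal sketch, assuming the toy model from the question; `LearnerWithAdd` and `MyModelV2` are hypothetical names, and averaging each vector's rows is an illustrative choice (the question's row-wise broadcast is not reproduced). It is checked here on CPU without FSDP.

```python
import torch
import torch.nn as nn


class LearnerWithAdd(nn.Module):
    """Hypothetical variant of Learner that consumes its own parameters in forward()."""

    def __init__(self, depth=3, dim=16):
        super().__init__()
        # Same parameter shapes as the question's Learner.
        self.vec_lst = nn.ModuleList(
            [nn.ParameterList([nn.Parameter(torch.rand(c + 1, dim)) for c in range(4)])
             for _ in range(depth)]
        )

    def forward(self, x, depth=1):
        # x: (bs, 1, dim). Parameters are only read here, i.e. while the FSDP
        # unit wrapping this module would hold them all-gathered.
        for vec in self.vec_lst[depth]:
            # Averaging the (c + 1) rows to shape (dim,) keeps the broadcast
            # against x valid for every c (illustrative choice).
            x = x + vec.mean(dim=0)
        return x  # a fresh activation, not a reference to the parameters


class MyModelV2(nn.Module):
    """Hypothetical restructured MyModel: only activations leave the learner."""

    def __init__(self):
        super().__init__()
        self.learner = LearnerWithAdd()

    def forward(self, x):
        x = self.learner(x, depth=1)
        return x[:, 0, :]


model = MyModelV2()
out = model(torch.randn(4, 1, 16))
out.sum().backward()
print(out.shape)  # torch.Size([4, 16])
```

Under FSDP, the question's wrap policy would then target the new class (lambda_fn=lambda m: isinstance(m, LearnerWithAdd)); that combination has not been run here.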

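P.S. Setting param.requires_grad_ = False does not do anything in your script: requires_grad_ is a method, so the assignment only shadows it with a plain attribute (use param.requires_grad = False or param.requires_grad_(False) instead), and FreezeModel has no parameters for the loop to visit anyway.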
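A quick way to see the difference, using two plain nn.Linear layers as hypothetical stand-ins for the pretrained model:

```python
import torch.nn as nn

broken = nn.Linear(4, 4)
for p in broken.parameters():
    # requires_grad_ is a method; assigning to it just shadows the method with a
    # plain attribute on this tensor. The real requires_grad flag is unchanged.
    p.requires_grad_ = False

fixed = nn.Linear(4, 4)
for p in fixed.parameters():
    p.requires_grad_(False)  # equivalently: p.requires_grad = False

print([p.requires_grad for p in broken.parameters()])  # [True, True]
print([p.requires_grad for p in fixed.parameters()])   # [False, False]
```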

Thank you very much for the clarification! It is very helpful!

If you have a more representative version, I can try to help a bit.

Thank you for pointing out the param.requires_grad_ = False mistake! Indeed, the script I posted is a simplified version of my code, to make debugging easier. In my original code, FreezeModel contains a pretrained model, which is why I try to freeze its parameters. Without auto_wrap_policy, the problem is that the FreezeModel parameters are frozen while the Learner parameters are trainable, so a single FSDP unit would contain parameters with non-uniform requires_grad. Could you please help with a more representative version?
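One arrangement consistent with the explanation earlier in the thread (an assumption, not verified under FSDP here): make only the frozen FreezeModel its own FSDP unit and leave Learner's parameters in the root unit. The root unit's parameters stay all-gathered for the whole of MyModel.forward(), so FreezeModel can still read them, and every unit then holds parameters with uniform requires_grad, which use_orig_params=False requires. The sketch below only checks the second property, on CPU and without FSDP; the Linear backbone, `freeze_policy_fn` and `requires_grad_by_unit` are hypothetical.

```python
import torch
import torch.nn as nn


class Learner(nn.Module):
    """Trainable vectors, with the same parameter shapes as in the question."""

    def __init__(self, depth=3, dim=16):
        super().__init__()
        self.vec_lst = nn.ModuleList(
            [nn.ParameterList([nn.Parameter(torch.rand(c + 1, dim)) for c in range(4)])
             for _ in range(depth)]
        )


class FreezeModel(nn.Module):
    """Hypothetical stand-in for the pretrained model: one frozen linear layer."""

    def __init__(self, dim=16):
        super().__init__()
        self.backbone = nn.Linear(dim, dim)
        for p in self.backbone.parameters():
            p.requires_grad = False


class MyModel(nn.Module):
    # forward() omitted: only the parameter layout matters for this check.
    def __init__(self):
        super().__init__()
        self.learner = Learner()
        self.freezed_model = FreezeModel()


def freeze_policy_fn(module):
    # Hypothetical lambda_fn for lambda_auto_wrap_policy: wrap only the frozen
    # submodule, so Learner's parameters stay in the root FSDP unit.
    return isinstance(module, FreezeModel)


def requires_grad_by_unit(model, policy_fn):
    """Simplified model of FSDP wrapping (one level of nesting): group parameters
    by the unit they would belong to and collect their requires_grad flags."""
    wrapped = [name for name, m in model.named_modules() if name and policy_fn(m)]
    units = {}
    for name, p in model.named_parameters():
        owner = next((w for w in wrapped if name.startswith(w + ".")), "root")
        units.setdefault(owner, set()).add(p.requires_grad)
    return units


model = MyModel()
print(requires_grad_by_unit(model, freeze_policy_fn))  # one flag per unit
print(requires_grad_by_unit(model, lambda m: False))   # no auto-wrap: mixed root
```

With the question's Learner-only policy, the flags are also uniform per unit, but that is exactly the layout in which Learner's storage is freed before FreezeModel reads it.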