[FSDP] Memory Duplication

Hi everyone, I am following this tutorial: Advanced Model Training with Fully Sharded Data Parallel (FSDP) — PyTorch Tutorials 2.0.0+cu117 documentation
I changed the task to token classification, but there are two main problems.
1st Problem (not related to FSDP):

  • It seems that the custom PyTorch training loop uses more memory than the Hugging Face Trainer (Hugging Face Trainer: 2.8 GB, custom PyTorch loop: 6.7 GB).

2nd Problem:

  • The training process consumes about 8 GB of GPU memory on each of the 2 GPUs. I tried to fix this by calling torch.cuda.empty_cache() after each training step. The memory did go down; however, the memory duplication was still there (see the measurement sketch below).
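To put numbers behind this, here is a minimal sketch of how the per-GPU usage can be checked from inside the script (the helper name is mine, just for illustration). Note that torch.cuda.empty_cache() only returns cached allocator blocks to the driver, so the number shown by nvidia-smi drops while the memory held by live tensors stays the same:

import torch
import torch.distributed as dist

def log_gpu_memory(tag: str) -> None:
    """Print per-rank GPU memory: live tensors vs. allocator cache vs. peak."""
    gb = 1024 ** 3
    rank = dist.get_rank() if dist.is_initialized() else 0
    allocated = torch.cuda.memory_allocated() / gb   # memory held by live tensors
    reserved = torch.cuda.memory_reserved() / gb     # allocator cache (roughly what nvidia-smi shows)
    peak = torch.cuda.max_memory_allocated() / gb    # high-water mark since the last reset
    print(f"[rank {rank}] {tag}: allocated={allocated:.2f} GB, "
          f"reserved={reserved:.2f} GB, peak={peak:.2f} GB")

# e.g. call log_gpu_memory(f"epoch {epoch}") right after optimizer.step() in the train loop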

Here is my code snippet:

import os
import tqdm
import time
import argparse
import functools
from datetime import datetime

import torch
import torch.optim as optim
import torch.distributed as dist

from torch.optim.lr_scheduler import StepLR
from transformers import AutoTokenizer, RobertaForTokenClassification
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from transformers.models.roberta.modeling_roberta import RobertaLayer

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    CPUOffload,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from dataset import NerDataset

g_gigabyte = 1024**3

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)

    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        print("1")
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        print("2")
        optimizer.zero_grad()
        print("3")
        output = model(**batch)
        print("4")
        loss = output["loss"]
        loss.backward()
        print("5")
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)

        if rank == 0:
            inner_pbar.update(1)
        torch.cuda.empty_cache()

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_loss = fsdp_loss[0] / fsdp_loss[1]

    if rank == 0:
        inner_pbar.close()
        print(
            f"Train Epoch: \t{epoch}, Loss: \t{train_loss:.4f}"
        )
    return train_loss

def validation(model, rank, world_size, val_loader):
    model.eval()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(**batch)
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank == 0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

def setup_model(model_name):
    model = RobertaForTokenClassification.from_pretrained(model_name, num_labels=23)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

def fsdp_main(args):

    model, tokenizer = setup_model("phobert-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train_dataset = NerDataset(tokenizer=tokenizer, file_path="phobert_train.json")
    val_dataset = NerDataset(tokenizer=tokenizer, file_path="phobert_eval.json")

    input_ids = torch.randint(low=0, high=10, size=(1, 20))

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 0,
                   'pin_memory': True,
                   'shuffle': False}

    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = DataLoader(train_dataset, **train_kwargs)
    val_loader = DataLoader(val_dataset, **test_kwargs)

    roberta_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            RobertaLayer,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP # for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)

    mp_policy = None # defaults to fp32

    model = FSDP(model,
        auto_wrap_policy=roberta_auto_wrap_policy,
        mixed_precision=mp_policy,
        # sharding_strategy=sharding_strategy,
        cpu_offload=CPUOffload(offload_params=True),
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "Roberta-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

    #init_end_event.record()

    #if rank == 0:
        #print(f"Cuda event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
        # print(f"{model}")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")

            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank == 0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

if __name__ == '__main__':
    import random
    import numpy as np

    def seed_everything(seed: int = 4) -> None:
        """ Ensure reproducibility"""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(seed)

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Roberta FSDP Example')
    parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    seed_everything(args.seed)

    import traceback

    try:
        fsdp_main(args)
    except:
        print(traceback.format_exc())

I'm actually a bit surprised that the custom training loop has higher memory usage than the Hugging Face Trainer (maybe it's because of state_dict?). From the code alone I can't tell which part is wrong.
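If you want to rule the state_dict path in or out, one quick check (just a sketch, reusing the FSDP-wrapped model and save policy from your script) is to bracket the FULL_STATE_DICT gather with memory readings and see whether the peak jumps there:

gb = 1024 ** 3
torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated() / gb

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()  # gathers the full state dict (rank0_only, offloaded to CPU)

after = torch.cuda.memory_allocated() / gb
peak = torch.cuda.max_memory_allocated() / gb
print(f"state_dict gather: before={before:.2f} GB, after={after:.2f} GB, peak={peak:.2f} GB")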

cc @agu @rvarm1 to check FSDP usage

@wanchaol Sorry about that. Here is my complete code.

import os
import tqdm
import time
import argparse
import functools
from datetime import datetime

import torch
import torch.optim as optim
import torch.distributed as dist

from torch.optim.lr_scheduler import StepLR
from transformers import AutoTokenizer, RobertaForTokenClassification
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from transformers.models.roberta.modeling_roberta import RobertaLayer

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    CPUOffload,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from dataset import NerDataset


g_gigabyte = 1024**3

def setup():
    # initialize the process group
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)


def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)
  
    if sampler:
        sampler.set_epoch(epoch)

    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        print("1")
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        print("2")
        optimizer.zero_grad()
        print("3")
        output = model(**batch)
        print("4")
        loss = output["loss"]
        loss.backward()
        print("5")
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)

        if rank==0:
            inner_pbar.update(1)
        # torch.cuda.empty_cache()
        
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_loss = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_loss:.4f}"
            )
    return train_loss


def validation(model, rank, world_size, val_loader):
    model.eval()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(**batch)
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)
        
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss


def setup_model(model_name):
    model = RobertaForTokenClassification.from_pretrained(model_name, num_labels=23)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

def fsdp_main(args):

    model, tokenizer = setup_model("phobert-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train_dataset = NerDataset(tokenizer=tokenizer, file_path="phobert_train.json")
    val_dataset = NerDataset(tokenizer=tokenizer, file_path="phobert_eval.json")
    
    input_ids = torch.randint(low=0, high=10, size=(1, 20))
    
    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 0,
                    'pin_memory': True,
                    'shuffle': False}
    
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = DataLoader(train_dataset, **train_kwargs)
    val_loader = DataLoader(val_dataset, **test_kwargs)
    
    roberta_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            RobertaLayer,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP # for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)
    
    mp_policy = None # defaults to fp32
    
    model = FSDP(model,
        auto_wrap_policy=roberta_auto_wrap_policy,
        mixed_precision=mp_policy,
        # sharding_strategy=sharding_strategy,
        cpu_offload=CPUOffload(offload_params=True),
        device_id=torch.cuda.current_device())
              
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "Roberta-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()
        
        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")
        

    #init_end_event.record()
        
    #if rank == 0:
        #print(f"Cuda event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
        # print(f"{model}")

        if args.save_model and curr_val_loss < best_val_loss:
            
            # save
            if rank == 0:
                print(f"--> entering save model state")
            
            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")
           

            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)
            
        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()


if __name__ == '__main__':
    import random 
    import numpy as np 

    def seed_everything(seed: int = 4) -> None:
        """ Ensure reproducibility"""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(seed)

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Roberta FSDP Example')
    parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    seed_everything(args.seed)
    
    import traceback 
    
    try:
        fsdp_main(args)
    except:
        print(traceback.format_exc())

@QuangTran Could you clarify what you mean by memory duplication? What is the expected behavior?

As I mentioned above, training on 1 GPU (without FSDP) consumes 6.7 GB of GPU memory. Using FSDP is expected to reduce the per-GPU memory, so each GPU should consume less than 6.7 GB, and the total across both GPUs should be roughly equal to or slightly higher than 6.7 GB. However, when I run the script, both GPUs consume 8 GB each.
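To make that comparison concrete, here is a rough sketch of the checks I can add right after the first training step (illustrative only; model is the FSDP-wrapped model from the script above). If sharding works the way I understand it, the flat-parameter shards held locally should amount to roughly total_params / world_size elements per rank:

rank = dist.get_rank()
gb = 1024 ** 3

# Parameter elements actually held by this rank (FSDP flat-parameter shards).
local_numel = sum(p.numel() for p in model.parameters())
print(f"[rank {rank}] local parameter elements: {local_numel}")

# Peak and cached GPU memory on this rank after the step.
print(f"[rank {rank}] peak allocated: {torch.cuda.max_memory_allocated() / gb:.2f} GB, "
      f"reserved: {torch.cuda.memory_reserved() / gb:.2f} GB")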