How to prevent `CUDA out of memory` error for a large MONAI network (SwinUNETR) with large patch-size images

For reference, I asked a similar question on the MONAI forum here, but couldn’t get a suitable response, so I am asking it here on the PyTorch forum to get more insights.

I am using the SwinUNETR network from the MONAI package (monai.networks.nets.SwinUNETR) to train a model for segmenting tumors from concatenated patches (along the channel dimension) of PET/CT images.

The patches from the PET/CT images are randomly cropped using monai.transforms.RandCropByPosNegLabeld. I want the training crop size to be (192, 192, 192) voxels, since I trained the UNet model (monai.networks.nets.UNet) on (192, 192, 192) crops and it performed better than with other patch sizes. Also, for a given training patch size (N, N, N), the patch/window size (M, M, M) used for sliding-window inference (monai.inferers.sliding_window_inference) matters too when it comes to getting an optimal Dice score on the validation set. This is illustrated in the plots below, where the UNet training patch size is N = 96 (left) and N = 192 (right), and in both cases various inference patch sizes (M = 96, 128, 160, 192, 224, 256, 288) were used for sliding-window inference. The best performance for N = 96 occurs at (N, M) = (96, 256), while for N = 192 it occurs at (N, M) = (192, 192). My best-performing UNet model is therefore the one with (N, M) = (192, 192), and I want to compare it with SwinUNETR in the same (N, M) configuration.
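For context, here is a minimal sketch of the sliding-window evaluation described above, with a stand-in predictor and an illustrative whole-volume shape (both assumptions on my part); roi_size plays the role of the inference window M:

import torch
from monai.inferers import sliding_window_inference

model = torch.nn.Conv3d(2, 2, kernel_size=3, padding=1)  # stand-in for the trained UNet/SwinUNETR
volume = torch.randn(1, 2, 256, 320, 256)  # whole PET/CT volume (illustrative shape)

with torch.no_grad():
    pred = sliding_window_inference(
        inputs=volume,
        roi_size=(192, 192, 192),  # inference window size M
        sw_batch_size=4,           # number of windows per forward pass
        predictor=model,
        overlap=0.25,              # overlap fraction between adjacent windows
    )
print(pred.shape)  # same spatial size as the input volume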

[Figure: validation Dice vs. inference window size M for the UNet trained with N = 96 (left) and N = 192 (right)]

But unfortunately, SwinUNETR's memory footprint grows with the input patch size, and I run out of GPU memory when I use (N, M) = (192, 192). Currently, I am running my experiments on a Microsoft Azure VM with 4 GPUs (16 GiB of memory each), using DistributedDataParallel(). I was wondering if there are any non-trivial ways of running this large model on these 4 GPUs without running out of memory. The largest (N, M) that I can comfortably run for SwinUNETR is (96, 96). Please let me know.

PS: I run out of memory even when I just use one concatenated patch of size (192, 192, 192) as input to the model.

Some suggestions (vaguely in order of increasing difficulty) in case you haven’t already tried them:

  1. Reducing the batch size
  2. Checking that you are on a recent/latest version of PyTorch built with CUDA 11.8+, as this has support for lazy loading (which can save GPU memory for kernels that your workload never uses). You may also want to set the CUDA_MODULE_LOADING="LAZY" environment variable (see the sketch after this list).
  3. Consider using other forms of parallelism for your model, e.g., FSDP (see the PyTorch tutorial "Getting Started with Fully Sharded Data Parallel (FSDP)") or pipeline parallelism.
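For point 2, a minimal sketch of opting in to lazy loading from within the script (you can equally export the variable in the shell that launches the job); CUDA_MODULE_LOADING is read when the CUDA context is created, so it must be set before the first CUDA call:

import os
os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")  # must happen before CUDA is initialized

import torch
print(torch.version.cuda)  # lazy loading needs a CUDA 11.7+/11.8 build
torch.cuda.init()          # kernels will now be loaded on first use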

Hi @eqy

Thanks for your suggestion. I started using the FSDP module as explained in the tutorial here. My model, monai.networks.nets.SwinUNETR (MONAI 1.1.0), still fails to train with an image patch size of (192, 192, 192), and I still get the CUDA out of memory error. Here's my code; the error message is given below:

#%%
# Import packages
import argparse
import functools
import os
import time
from datetime import datetime
from distutils.version import LooseVersion  # needed by the bf16_ready check below

import torch
import torch.distributed as dist
from torch.cuda import nccl  # needed by the bf16_ready check below
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block  # used by the auto-wrap policy

from monai.transforms import (
    AsDiscrete,
    Compose,
)
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, decollate_batch, ThreadDataLoader

from initialize_train import (
    get_spatial_size,
    get_train_valid_data_in_dict_format,
    get_train_transforms,
    get_valid_transforms,
    get_model,
    get_loss_function,
    get_optimizer,
    get_scheduler,
    get_metric,
)

#%%
g_gigabyte = 1024**3

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run


def convert_to_4digits(str_num):
    if len(str_num) == 1:
        new_num = '000' + str_num
    elif len(str_num) == 2:
        new_num = '00' + str_num
    elif len(str_num) == 3:
        new_num = '0' + str_num
    else:
        new_num = str_num
    return new_num

#%%
def load_train_objects():
    train_data, valid_data = get_train_valid_data_in_dict_format() 
    train_transforms = get_train_transforms()
    valid_transforms = get_valid_transforms()
    model = get_model('swinunetr')
    # optimizer = get_optimizer(model)
    loss_function = get_loss_function()
    # scheduler = get_scheduler(optimizer)
    metric = get_metric()

    return (
        train_data,
        valid_data,
        train_transforms,
        valid_transforms,
        model,
        loss_function,
        # optimizer,
        # scheduler,
        metric
    )


def prepare_dataset(data, transforms):
    dataset = CacheDataset(data=data, transform=transforms, cache_rate=0, num_workers=24)
    return dataset

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)

#%%
def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num
#%%
def train(
        args, 
        model, 
        rank, 
        world_size, 
        train_dataloader, 
        loss_function, 
        optimizer, 
        scheduler, 
        epoch, 
        sampler=None
):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)
  
    if sampler:
        sampler.set_epoch(epoch)

    for batch_data in train_dataloader:
        inputs, labels = (
            batch_data['CTPT'].to(local_rank),
            batch_data['GT'].to(local_rank),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += inputs.shape[0]

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_epoch_loss = (fsdp_loss[0] / fsdp_loss[1]).item()
    return train_epoch_loss


def validation(
        model, 
        rank, 
        world_size, 
        valid_dataloader, 
        metric, 
        post_pred,
        post_label,
):
    model.eval()
    local_rank = int(os.environ['LOCAL_RANK'])
   
    with torch.no_grad():
        for val_data in valid_dataloader:
            val_inputs, val_labels = (
                val_data['CTPT'].to(local_rank),
                val_data['GT'].to(local_rank),
            )
            roi_size = get_spatial_size()
            sw_batch_size = 4
            val_outputs = sliding_window_inference(
                val_inputs, roi_size, sw_batch_size, model)
            val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            # compute the Dice metric for the current iteration
            metric(y_pred=val_outputs, y=val_labels)

        metric_val = metric.aggregate().item()
        metric.reset()
    if rank == 0:
        print(f"Validation metric: {metric_val:.4f}")
    return metric_val
        

def fsdp_main(args):
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    train_data, valid_data, train_transforms, valid_transforms, model, loss_function, metric = load_train_objects()
    print(f"[GPU{local_rank}]: Training objects initialized")
    train_dataset = prepare_dataset(train_data, train_transforms)
    valid_dataset = prepare_dataset(valid_data, valid_transforms)


    train_sampler = DistributedSampler(dataset=train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    valid_sampler = DistributedSampler(dataset=valid_dataset, rank=rank, num_replicas=world_size, shuffle=False)
    
    setup()
    
    train_kwargs = {'batch_size': args.train_batch_size, 'sampler': train_sampler}
    valid_kwargs = {'batch_size': args.valid_batch_size, 'sampler': valid_sampler}
    cuda_kwargs = {'num_workers': 24,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    valid_kwargs.update(cuda_kwargs)

    train_dataloader = ThreadDataLoader(train_dataset, **train_kwargs)
    valid_dataloader = ThreadDataLoader(valid_dataset, **valid_kwargs)
    print(f"[GPU{local_rank}]: Train and Valid dataloaders initialized")
    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # SHARD_GRAD_OP for ZeRO-2, FULL_SHARD for ZeRO-3
    torch.cuda.set_device(local_rank)

    bf16_ready = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
    
    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32
    
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())
    
    optimizer = get_optimizer(model)
    scheduler = get_scheduler(optimizer)

    best_metric = -1
    best_metric_epoch = -1
    valid_interval = 2
    epoch_loss_values = []
    metric_values = []

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []
    
    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_epoch_loss = train(
            args, 
            model, 
            rank, 
            world_size, 
            train_dataloader, 
            loss_function, 
            optimizer, 
            scheduler, 
            epoch, 
            sampler=train_sampler
        )
        epoch_loss_values.append(train_epoch_loss)
        print(f"[GPU{local_rank}]: Epoch: {epoch}, TrainLoss:{train_epoch_loss}")
        scheduler.step()

        if args.run_validation and epoch % valid_interval == 0:
            valid_epoch_metric = validation(model, rank, world_size, valid_dataloader, metric, post_pred, post_label)
            metric_values.append(valid_epoch_metric)
        
        
        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_epoch_loss)

            if args.run_validation and epoch % valid_interval == 0:
                val_acc_tracking.append(valid_epoch_metric)
            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")
        


        
    if rank == 0:
        print(f"Total training time: {time.time() - training_start_time:.2f} sec")
        print(f"{model}")

    dist.barrier()
    cleanup()


if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='SwinUNETR FSDP training')
    parser.add_argument('--train-batch-size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: 1)')
    parser.add_argument('--valid-batch-size', type=int, default=1, metavar='N',
                        help='input batch size for validation (default: 1)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='pass this flag to disable GPU memory tracking (enabled by default)')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='pass this flag to disable validation (enabled by default)')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    
    fsdp_main(args)
# %%

2023-05-19 02:07:41,939 - Created a temporary directory at /tmp/tmpw3cbnv_3
2023-05-19 02:07:41,939 - Created a temporary directory at /tmp/tmpcgdj9otm
2023-05-19 02:07:41,939 - Created a temporary directory at /tmp/tmpn4j0cp2h
2023-05-19 02:07:41,939 - Created a temporary directory at /tmp/tmpr1zjuesx
2023-05-19 02:07:41,940 - Writing /tmp/tmpw3cbnv_3/_remote_module_non_scriptable.py
2023-05-19 02:07:41,940 - Writing /tmp/tmpn4j0cp2h/_remote_module_non_scriptable.py
2023-05-19 02:07:41,940 - Writing /tmp/tmpcgdj9otm/_remote_module_non_scriptable.py
2023-05-19 02:07:41,940 - Writing /tmp/tmpr1zjuesx/_remote_module_non_scriptable.py
[GPU2]: Training objects initialized
[GPU0]: Training objects initialized
[GPU3]: Training objects initialized[GPU1]: Training objects initialized

2023-05-19 02:07:43,413 - Added key: store_based_barrier_key:1 to store for rank: 1
2023-05-19 02:07:43,414 - Added key: store_based_barrier_key:1 to store for rank: 0
2023-05-19 02:07:43,416 - Added key: store_based_barrier_key:1 to store for rank: 3
2023-05-19 02:07:43,418 - Added key: store_based_barrier_key:1 to store for rank: 2
2023-05-19 02:07:43,418 - Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[GPU2]: Train and Valid dataloaders initialized
2023-05-19 02:07:43,424 - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[GPU1]: Train and Valid dataloaders initialized
2023-05-19 02:07:43,424 - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[GPU0]: Train and Valid dataloaders initialized
2023-05-19 02:07:43,427 - Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[GPU3]: Train and Valid dataloaders initialized
--> current date and time of run = 2023-05-19-02:07:45_AM
/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/_tensor.py:1295: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  ret = func(*args, **kwargs)
/data/anaconda/envs/dlml/lib/python3.8/site-packages/monai/data/__init__.py:120: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  t = cls([], dtype=storage.dtype, device=storage.device)
[... the same TypedStorage deprecation warning repeats for the remaining ranks and dataloader workers ...]
Traceback (most recent call last):
  File "trainfsdp_swinunetr.py", line 433, in <module>
    fsdp_main(args)
  File "trainfsdp_swinunetr.py", line 329, in fsdp_main
    train_epoch_loss = train(
  File "trainfsdp_swinunetr.py", line 186, in train
    loss.backward()
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/_tensor.py", line 478, in backward
    return handle_torch_function(
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/overrides.py", line 1551, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/monai/data/meta_tensor.py", line 268, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/_tensor.py", line 1295, in __torch_function__
    ret = func(*args, **kwargs)
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 648.00 MiB (GPU 3; 15.78 GiB total capacity; 13.03 GiB already allocated; 429.50 MiB free; 14.19 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[... the same traceback and OutOfMemoryError ("Tried to allocate 648.00 MiB ... 13.03 GiB already allocated") are raised on the other three ranks (GPU 0, GPU 1, GPU 2) ...]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 4882) of binary: /data/anaconda/envs/dlml/bin/python
Traceback (most recent call last):
  File "/data/anaconda/envs/dlml/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/data/anaconda/envs/dlml/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
trainfsdp_swinunetr.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-05-19_02:08:17
  host      : gpuvm00002jhubvm01.internal.cloudapp.net
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 4883)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2023-05-19_02:08:17
  host      : gpuvm00002jhubvm01.internal.cloudapp.net
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 4884)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2023-05-19_02:08:17
  host      : gpuvm00002jhubvm01.internal.cloudapp.net
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 4885)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-19_02:08:17
  host      : gpuvm00002jhubvm01.internal.cloudapp.net
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 4882)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

I also tried printing the number of trainable parameters for both SwinUNETR and UNet using torchinfo.summary; the call I use is sketched below.
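Roughly like this (a sketch assuming MONAI 1.1's SwinUNETR signature; the constructor arguments are illustrative, with feature_size=24 matching the output shapes below):

import torch
from monai.networks.nets import SwinUNETR
from torchinfo import summary

model = SwinUNETR(
    img_size=(192, 192, 192),  # training patch size N
    in_channels=2,             # concatenated CT + PET channels
    out_channels=2,            # background + tumor
    feature_size=24,
)
summary(model, input_size=(1, 2, 192, 192, 192))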

For the SwinUNETR, the output looks like this:

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
SwinUNETR                                               [1, 2, 192, 192, 192]     --
├─SwinTransformer: 1-1                                  [1, 24, 96, 96, 96]       --
│    └─PatchEmbed: 2-1                                  [1, 24, 96, 96, 96]       --
│    │    └─Conv3d: 3-1                                 [1, 24, 96, 96, 96]       408
│    └─Dropout: 2-2                                     [1, 24, 96, 96, 96]       --
│    └─ModuleList: 2-3                                  --                        --
│    │    └─BasicLayer: 3-2                             [1, 48, 48, 48, 48]       37,230
│    └─ModuleList: 2-4                                  --                        --
│    │    └─BasicLayer: 3-3                             [1, 96, 24, 24, 24]       120,540
│    └─ModuleList: 2-5                                  --                        --
│    │    └─BasicLayer: 3-4                             [1, 192, 12, 12, 12]      425,400
│    └─ModuleList: 2-6                                  --                        --
│    │    └─BasicLayer: 3-5                             [1, 384, 6, 6, 6]         1,588,080
├─UnetrBasicBlock: 1-2                                  [1, 24, 192, 192, 192]    --
│    └─UnetResBlock: 2-7                                [1, 24, 192, 192, 192]    --
│    │    └─Convolution: 3-6                            [1, 24, 192, 192, 192]    1,296
│    │    └─InstanceNorm3d: 3-7                         [1, 24, 192, 192, 192]    --
│    │    └─LeakyReLU: 3-8                              [1, 24, 192, 192, 192]    --
│    │    └─Convolution: 3-9                            [1, 24, 192, 192, 192]    15,552
│    │    └─InstanceNorm3d: 3-10                        [1, 24, 192, 192, 192]    --
│    │    └─Convolution: 3-11                           [1, 24, 192, 192, 192]    48
│    │    └─InstanceNorm3d: 3-12                        [1, 24, 192, 192, 192]    --
│    │    └─LeakyReLU: 3-13                             [1, 24, 192, 192, 192]    --
├─UnetrBasicBlock: 1-3                                  [1, 24, 96, 96, 96]       --
│    └─UnetResBlock: 2-8                                [1, 24, 96, 96, 96]       --
│    │    └─Convolution: 3-14                           [1, 24, 96, 96, 96]       15,552
│    │    └─InstanceNorm3d: 3-15                        [1, 24, 96, 96, 96]       --
│    │    └─LeakyReLU: 3-16                             [1, 24, 96, 96, 96]       --
│    │    └─Convolution: 3-17                           [1, 24, 96, 96, 96]       15,552
│    │    └─InstanceNorm3d: 3-18                        [1, 24, 96, 96, 96]       --
│    │    └─LeakyReLU: 3-19                             [1, 24, 96, 96, 96]       --
├─UnetrBasicBlock: 1-4                                  [1, 48, 48, 48, 48]       --
│    └─UnetResBlock: 2-9                                [1, 48, 48, 48, 48]       --
│    │    └─Convolution: 3-20                           [1, 48, 48, 48, 48]       62,208
│    │    └─InstanceNorm3d: 3-21                        [1, 48, 48, 48, 48]       --
│    │    └─LeakyReLU: 3-22                             [1, 48, 48, 48, 48]       --
│    │    └─Convolution: 3-23                           [1, 48, 48, 48, 48]       62,208
│    │    └─InstanceNorm3d: 3-24                        [1, 48, 48, 48, 48]       --
│    │    └─LeakyReLU: 3-25                             [1, 48, 48, 48, 48]       --
├─UnetrBasicBlock: 1-5                                  [1, 96, 24, 24, 24]       --
│    └─UnetResBlock: 2-10                               [1, 96, 24, 24, 24]       --
│    │    └─Convolution: 3-26                           [1, 96, 24, 24, 24]       248,832
│    │    └─InstanceNorm3d: 3-27                        [1, 96, 24, 24, 24]       --
│    │    └─LeakyReLU: 3-28                             [1, 96, 24, 24, 24]       --
│    │    └─Convolution: 3-29                           [1, 96, 24, 24, 24]       248,832
│    │    └─InstanceNorm3d: 3-30                        [1, 96, 24, 24, 24]       --
│    │    └─LeakyReLU: 3-31                             [1, 96, 24, 24, 24]       --
├─UnetrBasicBlock: 1-6                                  [1, 384, 6, 6, 6]         --
│    └─UnetResBlock: 2-11                               [1, 384, 6, 6, 6]         --
│    │    └─Convolution: 3-32                           [1, 384, 6, 6, 6]         3,981,312
│    │    └─InstanceNorm3d: 3-33                        [1, 384, 6, 6, 6]         --
│    │    └─LeakyReLU: 3-34                             [1, 384, 6, 6, 6]         --
│    │    └─Convolution: 3-35                           [1, 384, 6, 6, 6]         3,981,312
│    │    └─InstanceNorm3d: 3-36                        [1, 384, 6, 6, 6]         --
│    │    └─LeakyReLU: 3-37                             [1, 384, 6, 6, 6]         --
├─UnetrUpBlock: 1-7                                     [1, 192, 12, 12, 12]      --
│    └─Convolution: 2-12                                [1, 192, 12, 12, 12]      --
│    │    └─ConvTranspose3d: 3-38                       [1, 192, 12, 12, 12]      589,824
│    └─UnetResBlock: 2-13                               [1, 192, 12, 12, 12]      --
│    │    └─Convolution: 3-39                           [1, 192, 12, 12, 12]      1,990,656
│    │    └─InstanceNorm3d: 3-40                        [1, 192, 12, 12, 12]      --
│    │    └─LeakyReLU: 3-41                             [1, 192, 12, 12, 12]      --
│    │    └─Convolution: 3-42                           [1, 192, 12, 12, 12]      995,328
│    │    └─InstanceNorm3d: 3-43                        [1, 192, 12, 12, 12]      --
│    │    └─Convolution: 3-44                           [1, 192, 12, 12, 12]      73,728
│    │    └─InstanceNorm3d: 3-45                        [1, 192, 12, 12, 12]      --
│    │    └─LeakyReLU: 3-46                             [1, 192, 12, 12, 12]      --
├─UnetrUpBlock: 1-8                                     [1, 96, 24, 24, 24]       --
│    └─Convolution: 2-14                                [1, 96, 24, 24, 24]       --
│    │    └─ConvTranspose3d: 3-47                       [1, 96, 24, 24, 24]       147,456
│    └─UnetResBlock: 2-15                               [1, 96, 24, 24, 24]       --
│    │    └─Convolution: 3-48                           [1, 96, 24, 24, 24]       497,664
│    │    └─InstanceNorm3d: 3-49                        [1, 96, 24, 24, 24]       --
│    │    └─LeakyReLU: 3-50                             [1, 96, 24, 24, 24]       --
│    │    └─Convolution: 3-51                           [1, 96, 24, 24, 24]       248,832
│    │    └─InstanceNorm3d: 3-52                        [1, 96, 24, 24, 24]       --
│    │    └─Convolution: 3-53                           [1, 96, 24, 24, 24]       18,432
│    │    └─InstanceNorm3d: 3-54                        [1, 96, 24, 24, 24]       --
│    │    └─LeakyReLU: 3-55                             [1, 96, 24, 24, 24]       --
├─UnetrUpBlock: 1-9                                     [1, 48, 48, 48, 48]       --
│    └─Convolution: 2-16                                [1, 48, 48, 48, 48]       --
│    │    └─ConvTranspose3d: 3-56                       [1, 48, 48, 48, 48]       36,864
│    └─UnetResBlock: 2-17                               [1, 48, 48, 48, 48]       --
│    │    └─Convolution: 3-57                           [1, 48, 48, 48, 48]       124,416
│    │    └─InstanceNorm3d: 3-58                        [1, 48, 48, 48, 48]       --
│    │    └─LeakyReLU: 3-59                             [1, 48, 48, 48, 48]       --
│    │    └─Convolution: 3-60                           [1, 48, 48, 48, 48]       62,208
│    │    └─InstanceNorm3d: 3-61                        [1, 48, 48, 48, 48]       --
│    │    └─Convolution: 3-62                           [1, 48, 48, 48, 48]       4,608
│    │    └─InstanceNorm3d: 3-63                        [1, 48, 48, 48, 48]       --
│    │    └─LeakyReLU: 3-64                             [1, 48, 48, 48, 48]       --
├─UnetrUpBlock: 1-10                                    [1, 24, 96, 96, 96]       --
│    └─Convolution: 2-18                                [1, 24, 96, 96, 96]       --
│    │    └─ConvTranspose3d: 3-65                       [1, 24, 96, 96, 96]       9,216
│    └─UnetResBlock: 2-19                               [1, 24, 96, 96, 96]       --
│    │    └─Convolution: 3-66                           [1, 24, 96, 96, 96]       31,104
│    │    └─InstanceNorm3d: 3-67                        [1, 24, 96, 96, 96]       --
│    │    └─LeakyReLU: 3-68                             [1, 24, 96, 96, 96]       --
│    │    └─Convolution: 3-69                           [1, 24, 96, 96, 96]       15,552
│    │    └─InstanceNorm3d: 3-70                        [1, 24, 96, 96, 96]       --
│    │    └─Convolution: 3-71                           [1, 24, 96, 96, 96]       1,152
│    │    └─InstanceNorm3d: 3-72                        [1, 24, 96, 96, 96]       --
│    │    └─LeakyReLU: 3-73                             [1, 24, 96, 96, 96]       --
├─UnetrUpBlock: 1-11                                    [1, 24, 192, 192, 192]    --
│    └─Convolution: 2-20                                [1, 24, 192, 192, 192]    --
│    │    └─ConvTranspose3d: 3-74                       [1, 24, 192, 192, 192]    4,608
│    └─UnetResBlock: 2-21                               [1, 24, 192, 192, 192]    --
│    │    └─Convolution: 3-75                           [1, 24, 192, 192, 192]    31,104
│    │    └─InstanceNorm3d: 3-76                        [1, 24, 192, 192, 192]    --
│    │    └─LeakyReLU: 3-77                             [1, 24, 192, 192, 192]    --
│    │    └─Convolution: 3-78                           [1, 24, 192, 192, 192]    15,552
│    │    └─InstanceNorm3d: 3-79                        [1, 24, 192, 192, 192]    --
│    │    └─Convolution: 3-80                           [1, 24, 192, 192, 192]    1,152
│    │    └─InstanceNorm3d: 3-81                        [1, 24, 192, 192, 192]    --
│    │    └─LeakyReLU: 3-82                             [1, 24, 192, 192, 192]    --
├─UnetOutBlock: 1-12                                    [1, 2, 192, 192, 192]     --
│    └─Convolution: 2-22                                [1, 2, 192, 192, 192]     --
│    │    └─Conv3d: 3-83                                [1, 2, 192, 192, 192]     50
=========================================================================================================
Total params: 15,703,868
Trainable params: 15,703,868
Non-trainable params: 0
Total mult-adds (G): 635.80
=========================================================================================================
Input size (MB): 56.62
Forward/backward pass size (MB): 16561.66
Params size (MB): 62.02
Estimated Total Size (MB): 16680.31
=========================================================================================================

While for UNet, the output is:

=================================================================================================================================================
Layer (type:depth-idx)                                                                          Output Shape              Param #
=================================================================================================================================================
UNet                                                                                            [1, 3, 192, 192, 192]     --
├─Sequential: 1-1                                                                               [1, 3, 192, 192, 192]     --
│    └─ResidualUnit: 2-1                                                                        [1, 16, 96, 96, 96]       --
│    │    └─Conv3d: 3-1                                                                         [1, 16, 96, 96, 96]       880
│    │    └─Sequential: 3-2                                                                     [1, 16, 96, 96, 96]       7,874
│    └─SkipConnection: 2-2                                                                      [1, 32, 96, 96, 96]       --
│    │    └─Sequential: 3-3                                                                     [1, 16, 96, 96, 96]       19,278,802
│    └─Sequential: 2-3                                                                          [1, 3, 192, 192, 192]     --
│    │    └─Convolution: 3-4                                                                    [1, 3, 192, 192, 192]     2,602
│    │    └─ResidualUnit: 3-5                                                                   [1, 3, 192, 192, 192]     246
=================================================================================================================================================
Total params: 19,290,404
Trainable params: 19,290,404
Non-trainable params: 0
Total mult-adds (G): 100.49
=================================================================================================================================================
Input size (MB): 56.62
Forward/backward pass size (MB): 2644.03
Params size (MB): 77.16
Estimated Total Size (MB): 2777.82
=================================================================================================================================================

So the UNet actually has more trainable parameters than SwinUNETR, yet it fits on the 4 GPUs (even with DataParallel() or DistributedDataParallel()), whereas SwinUNETR does not. Does anyone understand this issue?

Note that memory usage scales differently with the number of parameters depending on the operation and its configuration. For example, for a convolution you can increase memory usage arbitrarily while keeping the number of parameters fixed by increasing the spatial dimensions of the input (and hence of the output). Likewise, you can decrease memory usage for a fixed number of parameters by decreasing the spatial dimensions of the input or by increasing the stride (both of which shrink the output). A toy sketch follows.
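As a quick illustration (toy layer sizes, not your model): the convolution's parameter count stays fixed, while the activations that must be kept for the backward pass grow cubically with the input's spatial size:

import torch
import torch.nn as nn

conv = nn.Conv3d(in_channels=2, out_channels=24, kernel_size=3, padding=1)
print("parameters:", sum(p.numel() for p in conv.parameters()))  # fixed count

for size in (96, 192):
    x = torch.randn(1, 2, size, size, size)
    y = conv(x)
    # output elements scale as size**3; these saved activations dominate memory
    print(f"input {size}^3 -> output elements: {y.numel():,}")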

At this point another thing you could try is activation checkpointing: torch.utils.checkpoint — PyTorch 2.0 documentation to trade recomputation cost for additional memory, but it is a bit involved and would require some experimentation and manual changes to your model. A basic strategy would to checkpoint the model into more and more segments (ideally with each segment being roughly equal in terms of memory usage) until the memory usage stops going down or you are satisfied with the peak memory usage.