I also tried printing the number of trainable parameters for both `SwinUNETR` and `UNet` using `torchinfo.summary`.
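For reference, the "Trainable params" line can also be reproduced without torchinfo; a minimal sketch, assuming a plain `torch.nn.Module` (the `Linear` layer here is just a toy stand-in for either network):

```python
import torch.nn as nn

# Count elements of all parameters with requires_grad=True, which is
# what torchinfo reports as "Trainable params".
def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Tiny example: Linear(3, 4) has 3*4 weights + 4 biases = 16 parameters.
print(count_trainable(nn.Linear(3, 4)))  # 16
```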
For the `SwinUNETR`, the output looks like this:
=========================================================================================================
Layer (type:depth-idx) Output Shape Param #
=========================================================================================================
SwinUNETR [1, 2, 192, 192, 192] --
├─SwinTransformer: 1-1 [1, 24, 96, 96, 96] --
│ └─PatchEmbed: 2-1 [1, 24, 96, 96, 96] --
│ │ └─Conv3d: 3-1 [1, 24, 96, 96, 96] 408
│ └─Dropout: 2-2 [1, 24, 96, 96, 96] --
│ └─ModuleList: 2-3 -- --
│ │ └─BasicLayer: 3-2 [1, 48, 48, 48, 48] 37,230
│ └─ModuleList: 2-4 -- --
│ │ └─BasicLayer: 3-3 [1, 96, 24, 24, 24] 120,540
│ └─ModuleList: 2-5 -- --
│ │ └─BasicLayer: 3-4 [1, 192, 12, 12, 12] 425,400
│ └─ModuleList: 2-6 -- --
│ │ └─BasicLayer: 3-5 [1, 384, 6, 6, 6] 1,588,080
├─UnetrBasicBlock: 1-2 [1, 24, 192, 192, 192] --
│ └─UnetResBlock: 2-7 [1, 24, 192, 192, 192] --
│ │ └─Convolution: 3-6 [1, 24, 192, 192, 192] 1,296
│ │ └─InstanceNorm3d: 3-7 [1, 24, 192, 192, 192] --
│ │ └─LeakyReLU: 3-8 [1, 24, 192, 192, 192] --
│ │ └─Convolution: 3-9 [1, 24, 192, 192, 192] 15,552
│ │ └─InstanceNorm3d: 3-10 [1, 24, 192, 192, 192] --
│ │ └─Convolution: 3-11 [1, 24, 192, 192, 192] 48
│ │ └─InstanceNorm3d: 3-12 [1, 24, 192, 192, 192] --
│ │ └─LeakyReLU: 3-13 [1, 24, 192, 192, 192] --
├─UnetrBasicBlock: 1-3 [1, 24, 96, 96, 96] --
│ └─UnetResBlock: 2-8 [1, 24, 96, 96, 96] --
│ │ └─Convolution: 3-14 [1, 24, 96, 96, 96] 15,552
│ │ └─InstanceNorm3d: 3-15 [1, 24, 96, 96, 96] --
│ │ └─LeakyReLU: 3-16 [1, 24, 96, 96, 96] --
│ │ └─Convolution: 3-17 [1, 24, 96, 96, 96] 15,552
│ │ └─InstanceNorm3d: 3-18 [1, 24, 96, 96, 96] --
│ │ └─LeakyReLU: 3-19 [1, 24, 96, 96, 96] --
├─UnetrBasicBlock: 1-4 [1, 48, 48, 48, 48] --
│ └─UnetResBlock: 2-9 [1, 48, 48, 48, 48] --
│ │ └─Convolution: 3-20 [1, 48, 48, 48, 48] 62,208
│ │ └─InstanceNorm3d: 3-21 [1, 48, 48, 48, 48] --
│ │ └─LeakyReLU: 3-22 [1, 48, 48, 48, 48] --
│ │ └─Convolution: 3-23 [1, 48, 48, 48, 48] 62,208
│ │ └─InstanceNorm3d: 3-24 [1, 48, 48, 48, 48] --
│ │ └─LeakyReLU: 3-25 [1, 48, 48, 48, 48] --
├─UnetrBasicBlock: 1-5 [1, 96, 24, 24, 24] --
│ └─UnetResBlock: 2-10 [1, 96, 24, 24, 24] --
│ │ └─Convolution: 3-26 [1, 96, 24, 24, 24] 248,832
│ │ └─InstanceNorm3d: 3-27 [1, 96, 24, 24, 24] --
│ │ └─LeakyReLU: 3-28 [1, 96, 24, 24, 24] --
│ │ └─Convolution: 3-29 [1, 96, 24, 24, 24] 248,832
│ │ └─InstanceNorm3d: 3-30 [1, 96, 24, 24, 24] --
│ │ └─LeakyReLU: 3-31 [1, 96, 24, 24, 24] --
├─UnetrBasicBlock: 1-6 [1, 384, 6, 6, 6] --
│ └─UnetResBlock: 2-11 [1, 384, 6, 6, 6] --
│ │ └─Convolution: 3-32 [1, 384, 6, 6, 6] 3,981,312
│ │ └─InstanceNorm3d: 3-33 [1, 384, 6, 6, 6] --
│ │ └─LeakyReLU: 3-34 [1, 384, 6, 6, 6] --
│ │ └─Convolution: 3-35 [1, 384, 6, 6, 6] 3,981,312
│ │ └─InstanceNorm3d: 3-36 [1, 384, 6, 6, 6] --
│ │ └─LeakyReLU: 3-37 [1, 384, 6, 6, 6] --
├─UnetrUpBlock: 1-7 [1, 192, 12, 12, 12] --
│ └─Convolution: 2-12 [1, 192, 12, 12, 12] --
│ │ └─ConvTranspose3d: 3-38 [1, 192, 12, 12, 12] 589,824
│ └─UnetResBlock: 2-13 [1, 192, 12, 12, 12] --
│ │ └─Convolution: 3-39 [1, 192, 12, 12, 12] 1,990,656
│ │ └─InstanceNorm3d: 3-40 [1, 192, 12, 12, 12] --
│ │ └─LeakyReLU: 3-41 [1, 192, 12, 12, 12] --
│ │ └─Convolution: 3-42 [1, 192, 12, 12, 12] 995,328
│ │ └─InstanceNorm3d: 3-43 [1, 192, 12, 12, 12] --
│ │ └─Convolution: 3-44 [1, 192, 12, 12, 12] 73,728
│ │ └─InstanceNorm3d: 3-45 [1, 192, 12, 12, 12] --
│ │ └─LeakyReLU: 3-46 [1, 192, 12, 12, 12] --
├─UnetrUpBlock: 1-8 [1, 96, 24, 24, 24] --
│ └─Convolution: 2-14 [1, 96, 24, 24, 24] --
│ │ └─ConvTranspose3d: 3-47 [1, 96, 24, 24, 24] 147,456
│ └─UnetResBlock: 2-15 [1, 96, 24, 24, 24] --
│ │ └─Convolution: 3-48 [1, 96, 24, 24, 24] 497,664
│ │ └─InstanceNorm3d: 3-49 [1, 96, 24, 24, 24] --
│ │ └─LeakyReLU: 3-50 [1, 96, 24, 24, 24] --
│ │ └─Convolution: 3-51 [1, 96, 24, 24, 24] 248,832
│ │ └─InstanceNorm3d: 3-52 [1, 96, 24, 24, 24] --
│ │ └─Convolution: 3-53 [1, 96, 24, 24, 24] 18,432
│ │ └─InstanceNorm3d: 3-54 [1, 96, 24, 24, 24] --
│ │ └─LeakyReLU: 3-55 [1, 96, 24, 24, 24] --
├─UnetrUpBlock: 1-9 [1, 48, 48, 48, 48] --
│ └─Convolution: 2-16 [1, 48, 48, 48, 48] --
│ │ └─ConvTranspose3d: 3-56 [1, 48, 48, 48, 48] 36,864
│ └─UnetResBlock: 2-17 [1, 48, 48, 48, 48] --
│ │ └─Convolution: 3-57 [1, 48, 48, 48, 48] 124,416
│ │ └─InstanceNorm3d: 3-58 [1, 48, 48, 48, 48] --
│ │ └─LeakyReLU: 3-59 [1, 48, 48, 48, 48] --
│ │ └─Convolution: 3-60 [1, 48, 48, 48, 48] 62,208
│ │ └─InstanceNorm3d: 3-61 [1, 48, 48, 48, 48] --
│ │ └─Convolution: 3-62 [1, 48, 48, 48, 48] 4,608
│ │ └─InstanceNorm3d: 3-63 [1, 48, 48, 48, 48] --
│ │ └─LeakyReLU: 3-64 [1, 48, 48, 48, 48] --
├─UnetrUpBlock: 1-10 [1, 24, 96, 96, 96] --
│ └─Convolution: 2-18 [1, 24, 96, 96, 96] --
│ │ └─ConvTranspose3d: 3-65 [1, 24, 96, 96, 96] 9,216
│ └─UnetResBlock: 2-19 [1, 24, 96, 96, 96] --
│ │ └─Convolution: 3-66 [1, 24, 96, 96, 96] 31,104
│ │ └─InstanceNorm3d: 3-67 [1, 24, 96, 96, 96] --
│ │ └─LeakyReLU: 3-68 [1, 24, 96, 96, 96] --
│ │ └─Convolution: 3-69 [1, 24, 96, 96, 96] 15,552
│ │ └─InstanceNorm3d: 3-70 [1, 24, 96, 96, 96] --
│ │ └─Convolution: 3-71 [1, 24, 96, 96, 96] 1,152
│ │ └─InstanceNorm3d: 3-72 [1, 24, 96, 96, 96] --
│ │ └─LeakyReLU: 3-73 [1, 24, 96, 96, 96] --
├─UnetrUpBlock: 1-11 [1, 24, 192, 192, 192] --
│ └─Convolution: 2-20 [1, 24, 192, 192, 192] --
│ │ └─ConvTranspose3d: 3-74 [1, 24, 192, 192, 192] 4,608
│ └─UnetResBlock: 2-21 [1, 24, 192, 192, 192] --
│ │ └─Convolution: 3-75 [1, 24, 192, 192, 192] 31,104
│ │ └─InstanceNorm3d: 3-76 [1, 24, 192, 192, 192] --
│ │ └─LeakyReLU: 3-77 [1, 24, 192, 192, 192] --
│ │ └─Convolution: 3-78 [1, 24, 192, 192, 192] 15,552
│ │ └─InstanceNorm3d: 3-79 [1, 24, 192, 192, 192] --
│ │ └─Convolution: 3-80 [1, 24, 192, 192, 192] 1,152
│ │ └─InstanceNorm3d: 3-81 [1, 24, 192, 192, 192] --
│ │ └─LeakyReLU: 3-82 [1, 24, 192, 192, 192] --
├─UnetOutBlock: 1-12 [1, 2, 192, 192, 192] --
│ └─Convolution: 2-22 [1, 2, 192, 192, 192] --
│ │ └─Conv3d: 3-83 [1, 2, 192, 192, 192] 50
=========================================================================================================
Total params: 15,703,868
Trainable params: 15,703,868
Non-trainable params: 0
Total mult-adds (G): 635.80
=========================================================================================================
Input size (MB): 56.62
Forward/backward pass size (MB): 16561.66
Params size (MB): 62.02
Estimated Total Size (MB): 16680.31
=========================================================================================================
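As a quick sanity check on these counts, the first `Conv3d: 3-1` entry (408 params) is exactly what a 2x2x2 patch-embedding convolution gives, assuming 2 input channels and `feature_size=24` (2 input channels also match the reported 56.62 MB input size for a `(1, 2, 192, 192, 192)` float32 tensor):

```python
# Reproduce "Conv3d: 3-1 ... 408" from the summary by hand.
# Assumptions: in_channels=2, feature_size=24, 2x2x2 patch kernel.
in_ch, feat, k = 2, 24, 2
weights = in_ch * feat * k ** 3  # 2 * 24 * 8 = 384 weight parameters
biases = feat                    # one bias per output channel
print(weights + biases)          # 408
```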
For the `UNet`, on the other hand, the output is:
=================================================================================================================================================
Layer (type:depth-idx) Output Shape Param #
=================================================================================================================================================
UNet [1, 3, 192, 192, 192] --
├─Sequential: 1-1 [1, 3, 192, 192, 192] --
│ └─ResidualUnit: 2-1 [1, 16, 96, 96, 96] --
│ │ └─Conv3d: 3-1 [1, 16, 96, 96, 96] 880
│ │ └─Sequential: 3-2 [1, 16, 96, 96, 96] 7,874
│ └─SkipConnection: 2-2 [1, 32, 96, 96, 96] --
│ │ └─Sequential: 3-3 [1, 16, 96, 96, 96] 19,278,802
│ └─Sequential: 2-3 [1, 3, 192, 192, 192] --
│ │ └─Convolution: 3-4 [1, 3, 192, 192, 192] 2,602
│ │ └─ResidualUnit: 3-5 [1, 3, 192, 192, 192] 246
=================================================================================================================================================
Total params: 19,290,404
Trainable params: 19,290,404
Non-trainable params: 0
Total mult-adds (G): 100.49
=================================================================================================================================================
Input size (MB): 56.62
Forward/backward pass size (MB): 2644.03
Params size (MB): 77.16
Estimated Total Size (MB): 2777.82
=================================================================================================================================================
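Putting the two sets of totals side by side (numbers copied from the summaries above) makes the mismatch I am puzzled by explicit:

```python
# Totals copied from the two torchinfo summaries above
# (single forward pass on one 192^3 volume, batch size 1).
swin = {"params": 15_703_868, "fwd_bwd_mb": 16_561.66}
unet = {"params": 19_290_404, "fwd_bwd_mb": 2_644.03}

print(unet["params"] / swin["params"])          # ~1.23: UNet has more parameters
print(swin["fwd_bwd_mb"] / unet["fwd_bwd_mb"])  # ~6.3: SwinUNETR reports far more activation memory
```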
So the `UNet` actually has more trainable parameters than the `SwinUNETR` (19.3 M vs. 15.7 M), yet it fits on the 4 GPUs (even with `DataParallel()` or `DistributedDataParallel()`), whereas the `SwinUNETR` does not. Does anyone understand this issue?