I was profiling a PyTorch model. According to my calculations it should need no more than ~280 MiB, but the profiler reports a peak usage of 2.6 GiB.
Here is a simple DCGAN example that reproduces the problem:
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
# Configuration for GPU execution
class Config:
    """Hard-coded settings for this run.

    Every setting is a class attribute, so a fresh instance has an empty
    ``__dict__``; call :meth:`as_dict` to see the effective values.
    """

    dataset = 'fake'      # selects the dataset branch further down
    dataroot = None       # dataset root; required unless dataset is 'fake'
    workers = 0           # passed to the DataLoader as num_workers
    batchSize = 64
    imageSize = 64
    nz = 100              # channel count of the generator's input noise
    ngf = 64              # base channel width of the generator
    ndf = 64              # base channel width of the discriminator
    niter = 1  # Single epoch for quick testing
    lr = 0.0002           # Adam learning rate
    beta1 = 0.5           # first Adam beta
    dry_run = False  # Quick single cycle test
    ngpu = 1  # Use GPU
    netG = ''             # generator checkpoint to load; '' skips loading
    netD = ''             # discriminator checkpoint to load; '' skips loading
    outf = '.'            # directory for sample images and checkpoints
    manualSeed = None     # None means a random seed is drawn at start-up
    classes = 'bedroom'   # comma-separated class list, used for LSUN only
    accel = True  # Use acceleration

    def as_dict(self):
        """Return the effective settings as a plain dict.

        Class-level defaults are collected first and then overlaid with any
        attribute that was assigned on this instance.
        """
        settings = {
            name: value
            for name, value in vars(type(self)).items()
            if not name.startswith('_') and not callable(value)
        }
        settings.update(vars(self))
        return settings


opt = Config()
# `opt.__dict__` would be empty here because nothing has been set on the
# instance yet, so print the merged view instead.
print(opt.as_dict())
# Create the output directory. An already existing directory is accepted,
# but other failures (e.g. missing permissions) are no longer swallowed.
os.makedirs(opt.outf, exist_ok=True)

# Draw a seed when none was configured, print it, and seed both Python's
# and PyTorch's random number generators with it.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# NOTE(review): benchmark mode makes cuDNN try several convolution algorithms
# on first use; some of them presumably need large temporary workspaces, which
# would show up as a short-lived peak far above the steady-state allocation.
# Re-run with `cudnn.benchmark = False` to confirm.
cudnn.benchmark = True
# Use GPU if available, fallback to CPU
# `torch.accelerator` is looked up defensively: on a PyTorch build that does
# not provide it, the CUDA / CPU fallbacks below must still be reachable.
_accelerator = getattr(torch, "accelerator", None)
if opt.accel and _accelerator is not None and _accelerator.is_available():
    device = _accelerator.current_accelerator()
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("Warning: GPU requested but not available, falling back to CPU")
print(f"Using device: {device}")
# Build the dataset selected by `opt.dataset`, record its channel count in
# `nc`, and wrap it in a shuffling DataLoader.
if opt.dataroot is None and str(opt.dataset).lower() != 'fake':
    raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset)

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3
else:
    # Without this branch an unrecognised name left `dataset` and `nc`
    # undefined and surfaced later as a NameError.
    raise ValueError("unknown dataset \"%s\"" % opt.dataset)

# Explicit check instead of `assert`, which is removed when Python runs with -O.
if not dataset:
    raise ValueError("dataset \"%s\" is empty" % opt.dataset)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
# Coerce the size-related settings to int once. The network classes defined
# in this file read nz, ngf and ndf as module-level names.
ngpu, nz, ngf, ndf = (
    int(getattr(opt, key)) for key in ("ngpu", "nz", "ngf", "ndf")
)
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise the parameters of one module in place.

    Meant to be handed to ``Module.apply``, which calls it for every
    sub-module. The decision is made on the module's class name:

    * names containing ``Conv``: weight drawn from a normal distribution
      with mean 0.0 and std 0.02;
    * names containing ``BatchNorm``: weight drawn with mean 1.0 and
      std 0.02, bias set to zero;
    * anything else is left untouched.

    Modules that match by name but carry no ``weight`` / ``bias`` (for
    example a batch-norm layer built with ``affine=False``) are skipped
    instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)
class Generator(nn.Module):
    """Generator network built from a stack of transposed convolutions.

    The layer sizes come from the module-level names ``nz``, ``ngf`` and
    ``nc``, which are read when an instance is constructed.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) of each stage that is
        # followed by batch-norm and ReLU; every kernel is 4x4.
        stages = (
            (nz, ngf * 8, 1, 0),       # input is Z       -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # (ngf*8) x 4 x 4   -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # (ngf*4) x 8 x 8   -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      # (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
        )
        layers = []
        for in_channels, out_channels, stride, padding in stages:
            layers.append(nn.ConvTranspose2d(in_channels, out_channels, 4,
                                             stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))
        # output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the stack and return the result.

        The batch is split across ``range(self.ngpu)`` devices only when the
        input lives on a CUDA/XPU device and more than one GPU was requested.
        """
        on_accelerator = input.is_cuda or input.is_xpu
        if on_accelerator and self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)
class Discriminator(nn.Module):
    """Discriminator network built from a stack of strided convolutions.

    The layer sizes come from the module-level names ``nc`` and ``ndf``,
    which are read when an instance is constructed.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # input is (nc) x 64 x 64; the first stage has no batch-norm
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three stages that double the channel count:
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        for width in (ndf, ndf * 2, ndf * 4):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # final stage reduces (ndf*8) x 4 x 4 to a single sigmoid output
        layers += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the network output for ``input`` flattened to one dimension.

        The batch is split across ``range(self.ngpu)`` devices only when the
        input lives on a CUDA/XPU device and more than one GPU was requested.
        """
        on_accelerator = input.is_cuda or input.is_xpu
        if on_accelerator and self.ngpu > 1:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        flat = scores.view(-1, 1)
        return flat.squeeze(1)
def detailed_memory_info():
    """Print the memory counters reported by ``torch.cuda`` on one line.

    Shown in MiB: currently allocated memory, reserved memory (labelled
    "Cached") and the peak allocation (labelled "Max"). Returns nothing.
    """
    bytes_per_mib = 1024**2
    readings = (
        torch.cuda.memory_allocated(),
        torch.cuda.memory_reserved(),
        torch.cuda.max_memory_allocated(),
    )
    allocated, cached, max_allocated = (value / bytes_per_mib for value in readings)
    print(f"Allocated: {allocated:.1f} MiB, Cached: {cached:.1f} MiB, Max: {max_allocated:.1f} MiB")
# Clear the peak counter so that the "Max" column only covers what follows.
# Guarded because the device selection in this script can fall back to CPU,
# where resetting CUDA statistics would presumably raise.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
print("Initial memory state:")
detailed_memory_info()
# Profiled run: build both networks, create the optimizers, then train while
# printing the CUDA memory counters after every step.
# NOTE(review): the original indentation was lost, so the extent of this
# `with` block is reconstructed (assumed to span the whole run) - confirm.
# `prof` is never read in the visible code.
with torch.autograd.profiler.profile(use_device='cuda') as prof:
    netG = Generator(ngpu).to(device)
    print("After netG creation:")
    detailed_memory_info()
    netG.apply(weights_init)
    print("After netG weights init:")
    detailed_memory_info()
    if opt.netG != '':
        # map_location keeps a checkpoint saved on another device loadable here.
        netG.load_state_dict(torch.load(opt.netG, map_location=device))
        print("After netG load:")
        detailed_memory_info()
    print(netG)

    netD = Discriminator(ngpu).to(device)
    print("After netD creation:")
    detailed_memory_info()
    netD.apply(weights_init)
    print("After netD weights init:")
    detailed_memory_info()
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD, map_location=device))
        print("After netD load:")
        detailed_memory_info()
    print(netD)

    criterion = nn.BCELoss()
    # Fixed noise batch, reused for the periodic sample images.
    fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
    print("After fixed_noise creation:")
    detailed_memory_info()
    real_label = 1
    fake_label = 0

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    print("After optimizers creation:")
    detailed_memory_info()

    if opt.dry_run:
        opt.niter = 1

    for epoch in range(opt.niter):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            print("After real_cpu to device:")
            detailed_memory_info()
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label,
                               dtype=real_cpu.dtype, device=device)
            print("After label creation:")
            detailed_memory_info()

            output = netD(real_cpu)
            print("After netD forward:")
            detailed_memory_info()
            errD_real = criterion(output, label)
            errD_real.backward()
            print("After errD_real backward:")
            detailed_memory_info()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            print("After noise creation:")
            detailed_memory_info()
            fake = netG(noise)
            print("After netG forward:")
            detailed_memory_info()
            label.fill_(fake_label)
            # detach(): this backward pass must not reach the generator.
            output = netD(fake.detach())
            print("After netD forward on fake:")
            detailed_memory_info()
            errD_fake = criterion(output, label)
            errD_fake.backward()
            print("After errD_fake backward:")
            detailed_memory_info()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()
            print("After optimizerD step:")
            detailed_memory_info()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            print("After netD forward on fake (G update):")
            detailed_memory_info()
            errG = criterion(output, label)
            errG.backward()
            print("After errG backward:")
            detailed_memory_info()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            print("After optimizerG step:")
            detailed_memory_info()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, opt.niter, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if i % 100 == 0:
                vutils.save_image(real_cpu,
                                  '%s/real_samples.png' % opt.outf,
                                  normalize=True)
                # The samples are only written to disk, so no gradients are
                # needed. Without no_grad this forward pass records an
                # autograd graph whose buffers stay alive until `fake` is
                # rebound, which distorts the memory readings.
                with torch.no_grad():
                    fake = netG(fixed_noise)
                print("After generating samples:")
                detailed_memory_info()
                vutils.save_image(fake.detach(),
                                  '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                                  normalize=True)

            if opt.dry_run:
                break

        # do checkpointing
        torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

    print("DCGAN GPU training completed successfully!")
    print("Final peak memory allocation:")
    detailed_memory_info()
Here is the output:
{}
Random Seed: 3137
Using device: cuda
Initial memory state:
Allocated: 0.0 MiB, Cached: 0.0 MiB, Max: 0.0 MiB
After netG creation:
Allocated: 13.7 MiB, Cached: 22.0 MiB, Max: 13.7 MiB
After netG weights init:
Allocated: 13.7 MiB, Cached: 22.0 MiB, Max: 13.7 MiB
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
After netD creation:
Allocated: 24.2 MiB, Cached: 42.0 MiB, Max: 24.2 MiB
After netD weights init:
Allocated: 24.2 MiB, Cached: 42.0 MiB, Max: 24.2 MiB
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
After fixed_noise creation:
Allocated: 24.2 MiB, Cached: 42.0 MiB, Max: 24.2 MiB
After optimizers creation:
Allocated: 24.2 MiB, Cached: 42.0 MiB, Max: 24.2 MiB
After real_cpu to device:
Allocated: 27.2 MiB, Cached: 42.0 MiB, Max: 27.2 MiB
After label creation:
Allocated: 27.2 MiB, Cached: 42.0 MiB, Max: 27.2 MiB
After netD forward:
Allocated: 71.2 MiB, Cached: 152.0 MiB, Max: 817.2 MiB
After errD_real backward:
Allocated: 37.8 MiB, Cached: 82.0 MiB, Max: 2681.3 MiB
After noise creation:
Allocated: 37.8 MiB, Cached: 82.0 MiB, Max: 2681.3 MiB
After netG forward:
Allocated: 100.8 MiB, Cached: 114.0 MiB, Max: 2681.3 MiB
After netD forward on fake:
Allocated: 144.8 MiB, Cached: 224.0 MiB, Max: 2681.3 MiB
After errD_fake backward:
Allocated: 100.8 MiB, Cached: 226.0 MiB, Max: 2681.3 MiB
After optimizerD step:
Allocated: 122.9 MiB, Cached: 226.0 MiB, Max: 2681.3 MiB
After netD forward on fake (G update):
Allocated: 166.9 MiB, Cached: 300.0 MiB, Max: 2681.3 MiB
After errG backward:
Allocated: 77.4 MiB, Cached: 100.0 MiB, Max: 2681.3 MiB
After optimizerG step:
Allocated: 106.5 MiB, Cached: 142.0 MiB, Max: 2681.3 MiB
According to my calculation, the maximum memory usage consists of the model weights, gradients, optimizer states, and activations.
Generator
🔍 Analyzing 14 layers...
------------------------------------------------------------
0. nn.ConvTranspose2d | Params: 819,200 | Memory: 3.12 MB
1. nn.BatchNorm2d | Params: 1,024 | Memory: 0.00 MB
2. nn.ReLU | Params: 0 | Memory: 0.00 MB
3. nn.ConvTranspose2d | Params: 2,097,152 | Memory: 8.00 MB
4. nn.BatchNorm2d | Params: 512 | Memory: 0.00 MB
5. nn.ReLU | Params: 0 | Memory: 0.00 MB
6. nn.ConvTranspose2d | Params: 524,288 | Memory: 2.00 MB
7. nn.BatchNorm2d | Params: 256 | Memory: 0.00 MB
8. nn.ReLU | Params: 0 | Memory: 0.00 MB
9. nn.ConvTranspose2d | Params: 131,072 | Memory: 0.50 MB
10. nn.BatchNorm2d | Params: 128 | Memory: 0.00 MB
11. nn.ReLU | Params: 0 | Memory: 0.00 MB
12. nn.ConvTranspose2d | Params: 3,072 | Memory: 0.01 MB
13. nn.Tanh | Params: 0 | Memory: 0.00 MB
------------------------------------------------------------
📊 TOTAL: 3,576,704 parameters, 13.64 MB
Discriminator
🔍 Analyzing 13 layers...
------------------------------------------------------------
0. nn.Conv2d | Params: 3,072 | Memory: 0.01 MB
1. nn.LeakyReLU | Params: 0 | Memory: 0.00 MB
2. nn.Conv2d | Params: 131,072 | Memory: 0.50 MB
3. nn.BatchNorm2d | Params: 256 | Memory: 0.00 MB
4. nn.LeakyReLU | Params: 0 | Memory: 0.00 MB
5. nn.Conv2d | Params: 524,288 | Memory: 2.00 MB
6. nn.BatchNorm2d | Params: 512 | Memory: 0.00 MB
7. nn.LeakyReLU | Params: 0 | Memory: 0.00 MB
8. nn.Conv2d | Params: 2,097,152 | Memory: 8.00 MB
9. nn.BatchNorm2d | Params: 1,024 | Memory: 0.00 MB
10. nn.LeakyReLU | Params: 0 | Memory: 0.00 MB
11. nn.Conv2d | Params: 8,192 | Memory: 0.03 MB
12. nn.Sigmoid | Params: 0 | Memory: 0.00 MB
------------------------------------------------------------
📊 TOTAL: 2,765,568 parameters, 10.55 MB
🎨 GENERATOR FORWARD PASS ANALYSIS
🚀 Forward pass analysis starting with shape: (64, 100, 1, 1)
================================================================================
0. nn.ConvTranspose2d | (64, 100, 1, 1) → (64, 512, 4, 4) | Sample: 0.03 MB | Batch: 2.00 MB
1. nn.BatchNorm2d | (64, 512, 4, 4) → (64, 512, 4, 4) | Sample: 0.03 MB | Batch: 2.00 MB
2. nn.ReLU | (64, 512, 4, 4) → (64, 512, 4, 4) | Sample: 0.03 MB | Batch: 2.00 MB
3. nn.ConvTranspose2d | (64, 512, 4, 4) → (64, 256, 8, 8) | Sample: 0.06 MB | Batch: 4.00 MB
4. nn.BatchNorm2d | (64, 256, 8, 8) → (64, 256, 8, 8) | Sample: 0.06 MB | Batch: 4.00 MB
5. nn.ReLU | (64, 256, 8, 8) → (64, 256, 8, 8) | Sample: 0.06 MB | Batch: 4.00 MB
6. nn.ConvTranspose2d | (64, 256, 8, 8) → (64, 128, 16, 16) | Sample: 0.12 MB | Batch: 8.00 MB
7. nn.BatchNorm2d | (64, 128, 16, 16) → (64, 128, 16, 16) | Sample: 0.12 MB | Batch: 8.00 MB
8. nn.ReLU | (64, 128, 16, 16) → (64, 128, 16, 16) | Sample: 0.12 MB | Batch: 8.00 MB
9. nn.ConvTranspose2d | (64, 128, 16, 16) → (64, 64, 32, 32) | Sample: 0.25 MB | Batch: 16.00 MB
10. nn.BatchNorm2d | (64, 64, 32, 32) → (64, 64, 32, 32) | Sample: 0.25 MB | Batch: 16.00 MB
11. nn.ReLU | (64, 64, 32, 32) → (64, 64, 32, 32) | Sample: 0.25 MB | Batch: 16.00 MB
12. nn.ConvTranspose2d | (64, 64, 32, 32) → (64, 3, 64, 64) | Sample: 0.05 MB | Batch: 3.00 MB
13. nn.Tanh | (64, 3, 64, 64) → (64, 3, 64, 64) | Sample: 0.05 MB | Batch: 3.00 MB
================================================================================
📊 Total activation memory for batch: 96.00 MB
📊 Total activation memory per sample: 1.50
🔍 DISCRIMINATOR FORWARD PASS ANALYSIS
🚀 Forward pass analysis starting with shape: (64, 3, 64, 64)
================================================================================
0. nn.Conv2d | (64, 3, 64, 64) → (64, 64, 32, 32) | Sample: 0.25 MB | Batch: 16.00 MB
1. nn.LeakyReLU | (64, 64, 32, 32) → (64, 64, 32, 32) | Sample: 0.25 MB | Batch: 16.00 MB
2. nn.Conv2d | (64, 64, 32, 32) → (64, 128, 16, 16) | Sample: 0.12 MB | Batch: 8.00 MB
3. nn.BatchNorm2d | (64, 128, 16, 16) → (64, 128, 16, 16) | Sample: 0.12 MB | Batch: 8.00 MB
4. nn.LeakyReLU | (64, 128, 16, 16) → (64, 128, 16, 16) | Sample: 0.12 MB | Batch: 8.00 MB
5. nn.Conv2d | (64, 128, 16, 16) → (64, 256, 8, 8) | Sample: 0.06 MB | Batch: 4.00 MB
6. nn.BatchNorm2d | (64, 256, 8, 8) → (64, 256, 8, 8) | Sample: 0.06 MB | Batch: 4.00 MB
7. nn.LeakyReLU | (64, 256, 8, 8) → (64, 256, 8, 8) | Sample: 0.06 MB | Batch: 4.00 MB
8. nn.Conv2d | (64, 256, 8, 8) → (64, 512, 4, 4) | Sample: 0.03 MB | Batch: 2.00 MB
9. nn.BatchNorm2d | (64, 512, 4, 4) → (64, 512, 4, 4) | Sample: 0.03 MB | Batch: 2.00 MB
10. nn.LeakyReLU | (64, 512, 4, 4) → (64, 512, 4, 4) | Sample: 0.03 MB | Batch: 2.00 MB
11. nn.Conv2d | (64, 512, 4, 4) → (64, 1, 1, 1) | Sample: 0.00 MB | Batch: 0.00 MB
12. nn.Sigmoid | (64, 1, 1, 1) → (64, 1, 1, 1) | Sample: 0.00 MB | Batch: 0.00 MB
================================================================================
📊 Total activation memory for batch: 74.00 MB
📊 Total activation memory per sample: 1.16 MB
Total memory usage:
# Final estimate
10.55 * 4 + 13.64 * 4 + 96.00 + 74.00 = 266.76 MB (the factor 4 covers the weights, their gradients, and the two Adam moment buffers)