Lately I’ve been working on a diffusion model with U-net which simply takes an image and a timestamp(later embedded to latent vector) as inputs. The output is the predicted noise of the corrupted image at the timestep and the model structure is as follows.
The building blocks
class DoubleConv(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.GroupNorm(1, out_c), #equivalent with LayerNorm
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
nn.GroupNorm(1, out_c), #equivalent with LayerNorm
nn.ReLU()
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class Down(nn.Module):
def __init__(self, in_c, out_c, emb_dim=128):
super().__init__()
self.down = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_c,out_c),
)
self.emb_layer = nn.Sequential(
nn.ReLU(),
nn.Linear(emb_dim, out_c),
)
def forward(self, x, t):
x = self.down(x)
t_emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + t_emb
class Up(nn.Module):
def __init__(self, in_c, out_c, emb_dim=128):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConv(in_c,out_c)
self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(emb_dim, out_c),
)
def forward(self, x, skip_x, t):
x = self.up(x)
x = torch.cat([skip_x, x], dim=1)
x = self.conv(x)
emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb
class SelfAttention(nn.Module):
def __init__(self, channels, size):
super(SelfAttention, self).__init__()
self.channels = channels
self.size = size
self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
self.ln = nn.LayerNorm([channels])
self.ff_self = nn.Sequential(
nn.LayerNorm([channels]),
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels),
)
def forward(self, x):
x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
x_ln = self.ln(x)
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
attention_value = attention_value + x
attention_value = self.ff_self(attention_value) + attention_value
return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
#input size: ([128])
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
# TODO: Double check the ordering here
return embeddings #output size: ([128,128])
The U-net
the structure is based on this graph
which is implemented as
class UNet(nn.Module):
def __init__(self, c_in=3, c_out=3, time_dim=128, device="cuda"):
super().__init__()
self.device = device
self.time_dim = time_dim
self.inc = DoubleConv(c_in, 64) #(b,3,64,64) -> (b,64,64,64)
self.down1 = Down(64, 128) #(b,64,64,64) -> (b,128,32,32)
self.sa1 = SelfAttention(128, 32) #(b,128,32,32) -> (b,128,32,32)
self.down2 = Down(128, 256) #(b,128,32,32) -> (b,256,16,16)
self.sa2 = SelfAttention(256, 16) #(b,256,16,16) -> (b,256,16,16)
self.down3 = Down(256, 256) #(b,256,16,16) -> (b,256,8,8)
self.sa3 = SelfAttention(256, 8) #(b,256,8,8) -> (b,256,8,8)
self.bot1 = DoubleConv(256, 512) #(b,256,8,8) -> (b,512,8,8)
self.bot2 = DoubleConv(512, 512) #(b,512,8,8) -> (b,512,8,8)
self.bot3 = DoubleConv(512, 256) #(b,512,8,8) -> (b,256,8,8)
self.up1 = Up(512, 128) #(b,512,8,8) -> (b,128,16,16) because the skip_x
self.sa4 = SelfAttention(128, 16) #(b,128,16,16) -> (b,128,16,16)
self.up2 = Up(256, 64) #(b,256,16,16) -> (b,64,32,32)
self.sa5 = SelfAttention(64, 32) #(b,64,32,32) -> (b,64,32,32)
self.up3 = Up(128, 64) #(b,128,32,32) -> (b,64,64,64)
self.sa6 = SelfAttention(64, 64) #(b,64,64,64) -> (b,64,64,64)
self.outc = nn.Conv2d(64, c_out, kernel_size=1) #(b,64,64,64) -> (b,3,64,64)
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(time_dim),
nn.Linear(time_dim, time_dim),
nn.ReLU()
)
def forward(self, x, t):
device = t.device
time = self.time_mlp(t)
#initial conv
x1 = self.inc(x)
#Down
x2 = self.down1(x1, time)
print(f'after down1: {torch.cuda.memory_allocated(device)}')
x2 = self.sa1(x2)
print(f'after sa1: {torch.cuda.memory_allocated(device)}')
x3 = self.down2(x2, time)
print(f'after down2: {torch.cuda.memory_allocated(device)}')
x3 = self.sa2(x3)
print(f'after sa2: {torch.cuda.memory_allocated(device)}')
x4 = self.down3(x3, time)
print(f'after down3: {torch.cuda.memory_allocated(device)}')
x4 = self.sa3(x4)
print(f'after sa3: {torch.cuda.memory_allocated(device)}')
#Bottle neck
x4 = self.bot1(x4)
x4 = self.bot2(x4)
x4 = self.bot3(x4)
#Up
x = self.up1(x4, x3, time)
print(f'after up1: {torch.cuda.memory_allocated(device)}')
x = self.sa4(x)
print(f'after sa4: {torch.cuda.memory_allocated(device)}')
x = self.up2(x, x2, time)
print(f'after up2: {torch.cuda.memory_allocated(device)}')
x = self.sa5(x)
print(f'after sa5: {torch.cuda.memory_allocated(device)}')
x = self.up3(x, x1, time)
print(f'after up3: {torch.cuda.memory_allocated(device)}')
x = self.sa6(x)
print(f'after sa6: {torch.cuda.memory_allocated(device)}')
#Output
output = self.outc(x)
return output
The Main Problem
So I got this model on google colab and coded a train loop for it, using the colab GPU to accelerate. However, this error message appears:
OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB (GPU 0; 14.75 GiB total capacity; 9.15 GiB already allocated; 4.10 GiB free; 9.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
I’ve tried out all the measures on the Internet, including reducing batch size, cleaning the cache and using python garbage collector. The problem isn’t resolved, so I doubt that there is a problem in my code. I’ve checked the training loop and I think there’s no variable that is related to the gradient accumulation (in fact, the loop stopped before the first iteration is even finished).
# the training loop
for epoch in range(1,addition_epochs+1):
for step, batch in enumerate(dataloader): #batch = (features, labels) = ()
optimizer.zero_grad()
init_memory = torch.cuda.memory_allocated(device)
print(f'init_memory: {torch.cuda.memory_allocated(device)}')
t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
x_noisy, noise = forward_diffusion_sample(batch[0], t, device)
noise_pred = model(x_noisy, t)
after_forward_memory = torch.cuda.memory_allocated(device)
print(after_forward_memory)
loss = F.l2_loss(noise, noise_pred)
loss.backward()
after_backward_memory = torch.cuda.memory_allocated(device)
print(after_backward_memory)
optimizer.step()
after_step_memory = torch.cuda.memory_allocated(device)
print(after_step_memory)
print(init_memory)
if epoch % 5 == 0 and step == 0:
print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
sample_plot_image()
After breaking up my code and logging out the allocated GPU memory, I found that the problem is with my forward process, the output is as follows
# start training loop
init_memory: 65500160 (about 63 mb, the memory after model.to(device))
after down1: 1051363328 (about 100 mb)
after sa1: 3872032768 (started getting bigger)
after down2: 4090204160
after sa2: 4560490496
after down3: 4627666944
after sa3: 4720072704
after up1: 5038913536
after sa4: 5341427712
after up2: 5626707968
after sa5: 8111833088
after up3: 9017870336 (about 8.6 gb, out of memory)
This is only the first iteration of the first epoch, so I am so confused how come the memory increased so fast. I am pretty certain there’s some blunder that I made in the model structure that causes such issues such as variables, but I can’t find the problem.
If you have any thoughts or know the solution to this topic, please let me know.
Many thanks!