Model design results in memory explosion

Lately I’ve been working on a diffusion model with a U-Net that takes an image and a timestep (later embedded into a latent vector) as inputs. The output is the predicted noise of the corrupted image at that timestep. The model structure is as follows.

The building blocks

import math
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
  def __init__(self, in_c, out_c):
    super().__init__()
    self.conv1 = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.GroupNorm(1, out_c), # equivalent to LayerNorm over (C, H, W)
        nn.ReLU()
    )
    self.conv2 = nn.Sequential(
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
        nn.GroupNorm(1, out_c), # equivalent to LayerNorm over (C, H, W)
        nn.ReLU()
    )
  
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    return x

class Down(nn.Module):
  def __init__(self, in_c, out_c, emb_dim=128):
    super().__init__()
    self.down = nn.Sequential(
        nn.MaxPool2d(2),
        DoubleConv(in_c,out_c),
    )

    self.emb_layer = nn.Sequential(
        nn.ReLU(),
        nn.Linear(emb_dim, out_c),
    )

  def forward(self, x, t):
    x = self.down(x)
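    # note: broadcasting alone would add t_emb here; .repeat materializes a full (b, out_c, H, W) copy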
    t_emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) 
    return x + t_emb

class Up(nn.Module):
    def __init__(self, in_c, out_c, emb_dim=128):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = DoubleConv(in_c,out_c)
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_c),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

class SelfAttention(nn.Module):
    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # input shape: (batch,), e.g. torch.Size([128])
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings  # output shape: (batch, dim), e.g. torch.Size([128, 128])
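
As a quick shape check (my own snippet, using the class above — not part of the model):

t = torch.randint(0, 1000, (128,)).long()   # one timestep per sample in the batch
emb = SinusoidalPositionEmbeddings(128)(t)
print(emb.shape)                            # torch.Size([128, 128])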

The U-net

The structure is based on the following diagram (image not shown here), which is implemented as:

class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=128, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        self.inc = DoubleConv(c_in, 64) #(b,3,64,64) -> (b,64,64,64)

        self.down1 = Down(64, 128) #(b,64,64,64) -> (b,128,32,32)
        self.sa1 = SelfAttention(128, 32) #(b,128,32,32) -> (b,128,32,32)
        self.down2 = Down(128, 256) #(b,128,32,32) -> (b,256,16,16)
        self.sa2 = SelfAttention(256, 16) #(b,256,16,16) -> (b,256,16,16)
        self.down3 = Down(256, 256) #(b,256,16,16) -> (b,256,8,8)
        self.sa3 = SelfAttention(256, 8) #(b,256,8,8) -> (b,256,8,8)

        self.bot1 = DoubleConv(256, 512) #(b,256,8,8) -> (b,512,8,8)
        self.bot2 = DoubleConv(512, 512) #(b,512,8,8) -> (b,512,8,8)
        self.bot3 = DoubleConv(512, 256) #(b,512,8,8) -> (b,256,8,8)

        self.up1 = Up(512, 128) #(b,256,8,8) -> (b,128,16,16); in_c=512 because of the concatenated skip_x
        self.sa4 = SelfAttention(128, 16) #(b,128,16,16) -> (b,128,16,16)
        self.up2 = Up(256, 64) #(b,128,16,16) -> (b,64,32,32); in_c=256 from the skip concat
        self.sa5 = SelfAttention(64, 32) #(b,64,32,32) -> (b,64,32,32)
        self.up3 = Up(128, 64) #(b,64,32,32) -> (b,64,64,64); in_c=128 from the skip concat
        self.sa6 = SelfAttention(64, 64) #(b,64,64,64) -> (b,64,64,64)

        self.outc = nn.Conv2d(64, c_out, kernel_size=1) #(b,64,64,64) -> (b,3,64,64)
    
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_dim),
                nn.Linear(time_dim, time_dim),
                nn.ReLU()
                )

    def forward(self, x, t):
        device = t.device
        time = self.time_mlp(t)

        #initial conv
        x1 = self.inc(x)
        
        #Down
        x2 = self.down1(x1, time)
        print(f'after down1: {torch.cuda.memory_allocated(device)}')
        x2 = self.sa1(x2)
        print(f'after sa1: {torch.cuda.memory_allocated(device)}')
        x3 = self.down2(x2, time)
        print(f'after down2: {torch.cuda.memory_allocated(device)}')
        x3 = self.sa2(x3)
        print(f'after sa2: {torch.cuda.memory_allocated(device)}')

        x4 = self.down3(x3, time)
        print(f'after down3: {torch.cuda.memory_allocated(device)}')
        x4 = self.sa3(x4)
        print(f'after sa3: {torch.cuda.memory_allocated(device)}')


        #Bottle neck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        #Up
        x = self.up1(x4, x3, time)
        print(f'after up1: {torch.cuda.memory_allocated(device)}')
        x = self.sa4(x)
        print(f'after sa4: {torch.cuda.memory_allocated(device)}')

        x = self.up2(x, x2, time)
        print(f'after up2: {torch.cuda.memory_allocated(device)}')
        x = self.sa5(x)
        print(f'after sa5: {torch.cuda.memory_allocated(device)}')

        x = self.up3(x, x1, time)
        print(f'after up3: {torch.cuda.memory_allocated(device)}')
        x = self.sa6(x)
        print(f'after sa6: {torch.cuda.memory_allocated(device)}')


        #Output
        output = self.outc(x)
        return output
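
A minimal smoke test of the annotated shapes (my own snippet; it assumes a CUDA device, since the forward pass calls torch.cuda.memory_allocated and will also emit the memory prints):

model = UNet().to("cuda")
x = torch.randn(2, 3, 64, 64, device="cuda")
t = torch.randint(0, 1000, (2,), device="cuda").long()
print(model(x, t).shape)  # expected: torch.Size([2, 3, 64, 64])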

The Main Problem

I put this model on Google Colab and wrote a training loop for it, using the Colab GPU for acceleration. However, this error message appears:

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB (GPU 0; 14.75 GiB total capacity; 9.15 GiB already allocated; 4.10 GiB free; 9.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I’ve tried all the measures I could find on the Internet, including reducing the batch size, clearing the CUDA cache, and running the Python garbage collector. The problem persists, so I suspect there is a problem in my code. I’ve checked the training loop and I don’t see any variable that would accumulate gradients (in fact, the loop stops before the first iteration even finishes).

# the training loop
for epoch in range(1,addition_epochs+1):
    for step, batch in enumerate(dataloader): # batch = (features, labels)
      optimizer.zero_grad()

      init_memory = torch.cuda.memory_allocated(device)
      print(f'init_memory: {torch.cuda.memory_allocated(device)}')

      t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
      x_noisy, noise = forward_diffusion_sample(batch[0], t, device)
      noise_pred = model(x_noisy, t)

      after_forward_memory = torch.cuda.memory_allocated(device)
      print(after_forward_memory)

      loss = F.mse_loss(noise, noise_pred)  # F.l2_loss does not exist in PyTorch; mse_loss is the L2 objective
      loss.backward() 

      after_backward_memory = torch.cuda.memory_allocated(device)
      print(after_backward_memory)

      optimizer.step()
      after_step_memory = torch.cuda.memory_allocated(device)
      print(after_step_memory)


      print(init_memory)

      if epoch % 5 == 0 and step == 0:
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
        sample_plot_image()

After breaking up my code and logging the allocated GPU memory, I found that the problem is in the forward pass; the output is as follows:

# start training loop
init_memory: 65500160 (about 63 mb, the memory after model.to(device))
after down1: 1051363328 (about 1 GB)
after sa1: 3872032768 (this is where it starts to blow up)
after down2: 4090204160
after sa2: 4560490496
after down3: 4627666944
after sa3: 4720072704
after up1: 5038913536
after sa4: 5341427712
after up2: 5626707968
after sa5: 8111833088
after up3: 9017870336 (about 8.4 GB, then out of memory)

This is only the first iteration of the first epoch, so I’m confused as to how the memory grew so fast. I’m fairly certain there is some blunder in my model structure causing this, but I can’t find the problem.

If you have any thoughts or know a solution, please let me know.
Many thanks!

I would recommend checking the activation shapes and comparing them to the reported memory usage, as the high usage might be expected.
Generally, memory usage comes from the parameters, the stored forward activations needed for the gradient computation, the gradients, and the optimizer states (if applicable, depending on the optimizer used).
This post gives you a simplified example and shows that the stored forward activations take much more memory than e.g. the parameters.
Based on the error message I would guess self.sa6 tries to allocate the 32 GB, since the corresponding print statement is missing.
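
As a rough accounting sketch (illustrative only; model refers to your UNet):

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params * 4 / 1024**2:.1f} MB of fp32 parameters")
# gradients add the same amount again; Adam adds roughly 2x more for its
# exp_avg/exp_avg_sq states; the remainder is mostly stored activations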


Thank you for replying, ptrblck.

Sorry, I’m not quite sure what you mean by ‘activation shape’, but the input shape of the images is (128, 3, 64, 64). I don’t think the problem is the batch size, since I tested the same dataset and input shape on another model, as follows:

import math
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1 
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])
        # Upsample
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])

        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)  # out_dim serves as the 1x1 kernel size

    def forward(self, x, timestep):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)

I don’t see much difference between these two models except for the self-attention blocks. The model above has about four times as many parameters as my current model, yet its memory allocation stays around 1 GB throughout the first iteration and it works just fine under the same conditions.
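(A quick way to check that rough parameter ratio, as a side note:)

print(sum(p.numel() for p in SimpleUnet().parameters()))  # the reference model
print(sum(p.numel() for p in UNet().parameters()))        # my model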

As you stated, the sa6 memory allocation wasn’t printed, but that is because the program stopped before the sixth attention block’s forward pass even finished. I believe the problem won’t be resolved simply by increasing GPU memory.

Does PyTorch reserve a certain amount of memory before the training loop starts, and stop when the process needs more than the reserved amount? I don’t know what causes the problem.

Big thanks anyway!

  1. What optimizer are you using?

  2. Are you accumulating the loss for statistics? If so, can you show that code? (By “accumulating” I mean a pattern like the sketch below.)
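
A hypothetical illustration of the problematic pattern (compute_loss is a stand-in helper, not your code):

total_loss = 0.0
for step, batch in enumerate(dataloader):
    loss = compute_loss(batch)
    total_loss += loss           # tensor sum: keeps every iteration's graph alive
    # total_loss += loss.item()  # detached Python float: the safe version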

Mr. J_Johnson, thanks for your reply.

What optimizer are you using?

I’m using Adam as my optimizer

from torch.optim import Adam
optimizer = Adam(model.parameters(), lr=0.001)

and later push the optimizer states to my device, which is cuda:

#push optimizer to current device
def optimizer_to(optim, device):
    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

optimizer_to(optimizer,device)

Are you accumulating the loss for statistics?

I don’t think I accumulate the loss, as I have checked for this issue before. I’m using loss.item() in my training loop to avoid that memory error.

for epoch in range(1,addition_epochs+1):
    for step, batch in enumerate(dataloader): 
      optimizer.zero_grad()

      t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
      x_noisy, noise = forward_diffusion_sample(batch[0], t, device)
      noise_pred = model(x_noisy, t)

      loss = F.mse_loss(noise, noise_pred)  # F.l2_loss does not exist in PyTorch; mse_loss is the L2 objective
      loss.backward() 

      optimizer.step()

      if epoch % 5 == 0 and step == 0:
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
        sample_plot_image()

On top of that, the iteration stops before the loss is even calculated; the process is terminated during the model’s very first forward pass.

Adam tends to allocate more memory than SGD: Optimizers memory usage
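
(For context, Adam lazily allocates an exp_avg and an exp_avg_sq buffer per parameter on the first step(), so its state alone is roughly twice the parameter memory; you can inspect it like this:)

for state in optimizer.state.values():
    print({k: v.shape for k, v in state.items() if torch.is_tensor(v)})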

You could either:

  1. try reducing the number of parameters,
  2. set your model, training data, and labels/targets to half precision, i.e. model.to(dtype=torch.float16) (see the mixed-precision sketch after this list),
  3. Use more memory efficient attention: GitHub - lucidrains/memory-efficient-attention-pytorch: Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"
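
For option 2, a common alternative to a straight float16 cast is automatic mixed precision; here is a minimal sketch of your loop adapted to it (x_noisy, t, noise, etc. as in your code):

scaler = torch.cuda.amp.GradScaler()
for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
    x_noisy, noise = forward_diffusion_sample(batch[0], t, device)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        noise_pred = model(x_noisy, t)
        loss = F.mse_loss(noise, noise_pred)
    scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()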

Update to 3: I see Pytorch has implemented a memory efficient attention in Torch 2.0. See here: torch.nn.functional.scaled_dot_product_attention — PyTorch master documentation
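
To make option 3 concrete, here is a hypothetical rewrite of your SelfAttention around F.scaled_dot_product_attention (PyTorch >= 2.0), which can dispatch to a memory-efficient kernel so the full (L, L) score matrix is never materialized. A sketch under those assumptions, not a tested drop-in:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MemEfficientSelfAttention(nn.Module):
    def __init__(self, channels, size, num_heads=4):
        super().__init__()
        self.channels = channels
        self.size = size
        self.num_heads = num_heads
        self.ln = nn.LayerNorm(channels)
        self.qkv = nn.Linear(channels, channels * 3)  # fused q/k/v projection
        self.proj = nn.Linear(channels, channels)     # mirrors MultiheadAttention's out_proj
        self.ff_self = nn.Sequential(
            nn.LayerNorm(channels),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        b = x.shape[0]
        L = self.size * self.size
        x = x.view(b, self.channels, L).swapaxes(1, 2)  # (b, L, c)
        x_ln = self.ln(x)
        q, k, v = self.qkv(x_ln).chunk(3, dim=-1)
        # split heads: (b, num_heads, L, head_dim)
        q, k, v = (z.view(b, L, self.num_heads, -1).transpose(1, 2) for z in (q, k, v))
        att = F.scaled_dot_product_attention(q, k, v)   # no (L, L) tensor stored
        att = att.transpose(1, 2).reshape(b, L, self.channels)
        x = self.proj(att) + x                          # residual, as in the original
        x = self.ff_self(x) + x
        return x.swapaxes(1, 2).view(b, self.channels, self.size, self.size)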

By “activation” I meant the intermediate output activations created by each layer, which also use memory.

This is exactly what I meant, and why I suggested checking self.sa6, as it seems to allocate the reported 32 GB.

I don’t think the optimizer or any other part is related; indeed, self.sa6 needs 32 GB:

class SelfAttention(nn.Module):
    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        print("beginnign of forward: {}MB".format(torch.cuda.memory_allocated()/1024**2))
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        print("after view: {}MB".format(torch.cuda.memory_allocated()/1024**2))
        x_ln = self.ln(x)
        print("after layernorm: {}MB".format(torch.cuda.memory_allocated()/1024**2))
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        print("after mha: {}MB".format(torch.cuda.memory_allocated()/1024**2))
        attention_value = attention_value + x
        print("after addition: {}MB".format(torch.cuda.memory_allocated()/1024**2))
        attention_value = self.ff_self(attention_value) + attention_value
        print("after ff_self: {}MB".format(torch.cuda.memory_allocated()/1024**2))
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)


print(torch.cuda.memory_allocated()/1024**2)
# 0.0

sa6 = SelfAttention(64, 64).cuda() #(b,64,64,64) -> (b,64,64,64)

print(torch.cuda.memory_allocated()/1024**2)
# 0.09814453125

x = torch.randn(128, 64, 64, 64, device="cuda")
print(torch.cuda.memory_allocated()/1024**2)
# 128.09814453125

out = sa6(x)
# beginning of forward: 128.09814453125MB
# after view: 128.09814453125MB
# after layernorm: 260.09814453125MB
# OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB (GPU 0; 23.69 GiB total capacity; 908.22 MiB already allocated; 21.00 GiB free; 1.15 GiB reserved in total by PyTorch)
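
For completeness, the 32 GiB matches the attention score matrix that nn.MultiheadAttention materializes here:

batch, heads = 128, 4
L = 64 * 64                                 # 4096 tokens for the 64x64 feature map
print(batch * heads * L * L * 4 / 1024**3)  # 32.0 GiB of fp32 scores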

Thanks for your suggestions, Mr. Johnson

It turns out that the problem is indeed with the attention blocks.

I commented out all of the attention blocks and the model now works just fine, so I assume I can add them back once I have better GPU resources. In that case, I believe using a more memory-efficient attention would fix the issue, so I’m marking your answer as the solution.

Huge thanks for your help!

Thank you for your wonderful analysis of the problem, Mr. ptrblck.

I’m certain that the problem is caused by the attention blocks. After commenting them out, my model now works just fine. I believe that once I acquire a better GPU I will be able to add them back.

Your precise insight into my issue helped me a lot, and I’m truly grateful for your assistance.

Big thanks for your help!