CUDA out of memory while training

After reproducing the FQGAN implementation, a CUDA out of memory error occurred while training the model.

However, when I removed the feature quantization part that I implemented myself and ran it, no error occurred.

This means that the error lies either in the module I implemented or in the FQ part where its loss value is added.

Here is the actual code. The error probably occurs around cells No. 25, 26, or 34.

[Colab] FQGAN

How can I fix the error?

I assume you’ve implemented the VectorQuantizerEMA module?
Are you seeing the OOM directly in the first iteration or are you seeing an increased memory usage?
In the former case, you might need to reduce the batch size or, if that's not possible, you could try torch.utils.checkpoint to trade compute for memory.
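
Something along these lines could work as a sketch (Net and the blocks below are placeholders, not your model):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # made-up blocks standing in for any memory-heavy sub-modules
        self.block1 = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU())
        self.block2 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(64, 3, 1)

    def forward(self, x):
        # the activations of block1/block2 are recomputed during backward
        # instead of being stored, trading extra compute for lower peak memory
        x = checkpoint(self.block1, x)
        x = checkpoint(self.block2, x)
        return self.head(x)

x = torch.randn(2, 3, 32, 32, requires_grad=True)
Net()(x).mean().backward()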

My error is the latter case.

After about 500 iterations of training, an out of memory error occurs.

So this doesn't seem to be a batch-size issue; rather, GPU memory is not being freed during training, and intermediate computations keep accumulating.
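
A rough check like the following can confirm that kind of growth (num_steps and train_step are just placeholders for the actual training loop, not code from the Colab):

import torch

for step in range(num_steps):      # num_steps / train_step stand in for
    loss = train_step()            # the real training loop
    # allocated memory should stay roughly flat; steady growth means
    # tensors (e.g. a growing computation graph) are being kept alive
    print(step, torch.cuda.memory_allocated() / 1024**2, "MiB")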

When I skip the module I implemented in the forward pass below, training runs with no errors.

def forward(self, inputs):
    h1 = self.layer1(inputs)
    h2 = self.layer2(h1)

    h3 = self.layer3(h2)
    # no error occurred when VectorQuantizerEMA is skipped
    # h3, loss, perplexity = self.vq(h3)
        
    h4 = self.layer4(h3)
    out = self.sigmoid(self.last_conv(h4))
    return out, None, None

This leads me to believe that my implementation of VectorQuantizerEMA contains a memory-related error in its calculation.

The full VectorQuantizerEMA implementation is below.

import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizerEMA(nn.Module):
    def __init__(self, emb_dim, num_emb, commitment=0.25, decay=0.99, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self.emb_dim = emb_dim
        self.num_emb = num_emb
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon

        embed = torch.randn(emb_dim, num_emb)
        self.register_buffer('embed', embed)
        self.register_buffer('cluster_size', torch.zeros(num_emb))
        self.register_buffer('ema_embed', embed.clone())
    
    def forward(self, inputs):
        # [B, C=D, H, W] --> [B, H, W, C=D]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        inputs_shape = inputs.size()

        # flatten: [B, H, W, C=D] --> [BxHxW=N, D]
        flatten = inputs.view(-1, self.emb_dim)

        # distance: d(W[N, D], E[D, K]) <-- [N, K]
        distance = (
            flatten.pow(2).sum(1, keepdim=True)
            -2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )

        # minimum index: [N, K] --> [N, ]
        embed_idx = torch.argmin(distance, dim=1)

        # quantize: [N, ]x[K, D] --> [N, D] --> [B, H, W, D]
        quantize = F.embedding(embed_idx, self.embed.transpose(0, 1))
        quantize = quantize.view(*inputs_shape)

        # set OneHot label: [N, ] --> [N, K]
        embed_onehot = F.one_hot(embed_idx, num_classes=self.num_emb).type(flatten.dtype)
        
        # update the embedding vectors only in training mode (model.train())
        if self.training:
            # ref_counts: [N, K] --> [K, ]
            ref_counts = torch.sum(embed_onehot, dim=0)
            
            # ema for reference counts: [K, ]
            self.cluster_size = (
                self.decay * self.cluster_size + \
                (1 - self.decay) * ref_counts
            )

            # total reference count
            n = self.cluster_size.sum()

            # laplace smoothing
            self.cluster_size = n * (
                (self.cluster_size + self.epsilon)
                / (n + self.cluster_size * self.epsilon)
            )

            # dw: [D, N] @ [N, K]
            dw = flatten.transpose(0, 1) @ embed_onehot

            # ema for embeddings: [D, K]
            self.ema_embed = (
                self.decay * self.ema_embed + (1 - self.decay) * dw
            )

            # normalize by reference counts: [D, K] / [1, K] <-- [K, ]
            self.embed = self.ema_embed / self.cluster_size.unsqueeze(0)

        # loss
        e_latent_loss = F.mse_loss(quantize.detach(), inputs)
        loss = self.commitment * e_latent_loss

        # Straight Through Estimator
        quantize = inputs + (quantize - inputs).detach()
        # average probability of each code being used: [N, K] --> [K, ]
        avg_probs = torch.mean(embed_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantize.permute(0, 3, 1, 2).contiguous(), loss, perplexity

One possible cause of the increased memory usage might be the storage of the computation graph.
embed, cluster_size, and ema_embed are created as buffers, which would register the tensors without making them trainable (their requires_grad attribute would be False).
However, in the forward method you are reassigning some values to these buffers.
Could you check if their grad_fn is suddenly pointing to a valid function and, if so, use detach() on the assignments:

self.cluster_size = (
    self.decay * self.cluster_size +
    (1 - self.decay) * ref_counts
).detach()
# same for every other buffer assignment
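
For example, a quick check you could add right after those assignments in forward() (just a debugging sketch):

# buffers should not be part of the autograd graph
print(self.cluster_size.grad_fn)  # None for a plain buffer
print(self.ema_embed.grad_fn)     # e.g. <AddBackward0> here means the buffer
print(self.embed.grad_fn)         # now drags the computation graph along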

When this error happened to me earlier, decreasing the batch size resolved the issue (I went from 64 down to 20), as @ptrblck suggested in a previous post.

As you pointed out, the cause was in the module that computes the exponential moving averages: a computation graph was constantly being attached to the variables I had registered as buffers.

In the final code, instead of detaching them from the current computation graph with detach(), I changed the updates to edit the buffer data directly with in-place operations such as add_().
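
For reference, a minimal sketch of that style of update for the assignments inside the if self.training: block (an illustration, not the exact code from the Colab):

# in-place updates on the buffers' data bypass autograd, so no computation
# graph is kept alive (wrapping the block in torch.no_grad() works as well)
self.cluster_size.data.mul_(self.decay).add_(ref_counts, alpha=1 - self.decay)

n = self.cluster_size.sum()
self.cluster_size.data.copy_(
    n * (self.cluster_size + self.epsilon)
    / (n + self.cluster_size * self.epsilon)
)

self.ema_embed.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)
self.embed.data.copy_(self.ema_embed / self.cluster_size.unsqueeze(0))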

The problem has been solved. Thank you!

[Colab] FQGAN succeeded