Memory Usage Skyrockets During Training

accumulation_steps = 4
epochs_times = 1

for epoch in range(epochs):
    total_loss = 0
    batch_times = 0
    for idx, batch in enumerate(train_loader):
        src = batch['source'].to(device)
        tgt = batch['target'].to(device)
        # Prepare target data for model input and loss calculation
        tgt_input = batch['target'][:, :-1].to(device)
        tgt_output = batch['target'].to(device)

        sequence_length = tgt_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)
        pred = model(src, tgt_input, tgt_mask)
        # Reset gradients

        # Forward pass
        output = model(src, tgt)

        # Reshape output for loss calculation

        # Calculate loss
        loss = criterion(output, tgt_output) / accumulation_steps
        total_loss += loss.item()

        # Backward pass and optimize
        del src, tgt, tgt_input, tgt_output, loss
        if (idx + 1) % accumulation_steps == 0:

            batch_times += 4
            print(f"{batch_times}th batch of {epochs_times}th epoch is done")
    epochs_times += 1

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')

The code above is what I use to train my transformer model for translation.

While training the model, the memory usage suddenly increases like a step function.

Then an out-of-memory error occurs.

[Image: Screenshot 2024-02-05 194030]

(The image above shows the VRAM usage graph.)
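To find the batch at which such a step happens, the allocated VRAM can be logged from inside the training loop. The sketch below is only an illustration: the helper name `cuda_memory_mib` is hypothetical, and it relies on `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved`, and `torch.cuda.max_memory_allocated`. Without a CUDA device it simply returns zeros.

```python
import torch


def cuda_memory_mib(device=None):
    """Return (allocated, reserved, peak) CUDA memory in MiB, or zeros without CUDA."""
    if not torch.cuda.is_available():
        return 0.0, 0.0, 0.0
    mib = 1024 ** 2
    return (
        torch.cuda.memory_allocated(device) / mib,      # memory held by live tensors
        torch.cuda.memory_reserved(device) / mib,       # memory held by the caching allocator
        torch.cuda.max_memory_allocated(device) / mib,  # peak tensor memory so far
    )


allocated, reserved, peak = cuda_memory_mib()
print(f"allocated={allocated:.1f} MiB reserved={reserved:.1f} MiB peak={peak:.1f} MiB")
```

Printing these values once per batch, next to something like the batch index and `tgt_input.size(1)`, may show whether a jump lines up with a particular batch or sequence length.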

At first, I tried other techniques such as gradient accumulation, but they did not help.
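For comparison, a gradient accumulation loop is commonly structured as in the sketch below. This is not the model from this post: the tiny `nn.Linear` model, the random data, the SGD optimizer, and the function name `train_one_epoch` are all assumptions chosen to keep the example runnable. What it shows is one forward pass per batch, a `backward()` call on the scaled loss, and an optimizer step followed by `zero_grad()` every `accumulation_steps` batches.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, accumulation_steps, device):
    """Run one epoch with gradient accumulation and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    for idx, (src, tgt) in enumerate(loader):
        src, tgt = src.to(device), tgt.to(device)
        output = model(src)  # a single forward pass per batch
        loss = criterion(output, tgt)
        # Scale the loss so the summed gradients approximate one larger batch.
        (loss / accumulation_steps).backward()
        # .item() returns a Python float, so no autograd graph is kept alive.
        total_loss += loss.item()
        # Step on every accumulation_steps-th batch and on the final batch.
        if (idx + 1) % accumulation_steps == 0 or (idx + 1) == len(loader):
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
    return total_loss / len(loader)


# Hypothetical toy setup, only here to make the sketch runnable.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
loader = DataLoader(dataset, batch_size=4)
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
avg_loss = train_one_epoch(model, loader, nn.MSELoss(), optimizer, 4, device)
print(f"Loss: {avg_loss:.4f}")
```

Gradient accumulation lowers memory only by allowing a smaller per-step batch size. The peak memory of a single forward and backward pass still depends on the size of that one batch.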

Would you please help me solve this problem?