```python
accumulation_steps = 4
epochs_times = 1

for epoch in range(epochs):
    total_loss = 0
    batch_times = 0

    for idx, batch in enumerate(train_loader):
        src = batch['source'].to(device)
        tgt = batch['target'].to(device)

        # Prepare target data for model input and loss calculation
        tgt_input = batch['target'][:, :-1].to(device)
        tgt_output = batch['target'].to(device)

        sequence_length = tgt_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)
        pred = model(src, tgt_input, tgt_mask)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(src, tgt)

        # Reshape output for loss calculation
        output = output.permute(1, 2, 0)

        # Calculate loss
        loss = criterion(output, tgt_output) / accumulation_steps
        total_loss += loss.item()

        # Backward pass and optimize
        loss.backward()
        del src, tgt, tgt_input, tgt_output, loss

        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            batch_times += 4
            print(f"{batch_times}th batch of {epochs_times}th epoch is done")

    epochs_times += 1
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')
```

The code above is the training loop for my transformer translation model.

While training the model, the memory usage suddenly increases like a step function, and an Out of Memory error occurs.

(The image above shows the VRAM usage graph.)
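To pin down exactly which step triggers the jump, the per-step VRAM numbers can be logged with PyTorch's standard allocator-stat APIs; this is a small sketch (the helper name `log_vram` is my own, and it is a no-op on machines without CUDA):

```python
import torch

def log_vram(step):
    # torch.cuda.memory_allocated / memory_reserved are standard
    # PyTorch allocator statistics; prints nothing without CUDA.
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2  # bytes held by live tensors
        reserved = torch.cuda.memory_reserved() / 1024**2    # bytes held by the caching allocator
        print(f"step {step}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")
```

Calling `log_vram(idx)` once per batch makes the step-like growth visible in the logs rather than only in the graph.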

At first, I tried other techniques like gradient accumulation, but they did not work.
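For reference, this is the standard gradient-accumulation pattern I tried to follow, as a minimal CPU-only sketch with a toy linear model standing in for my transformer (the names `batches` and `num_updates` are just for illustration):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real model, optimizer, and data loader
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accumulation_steps = 4

batches = [(torch.randn(16, 8), torch.randint(0, 2, (16,))) for _ in range(8)]

num_updates = 0
optimizer.zero_grad()
for idx, (x, y) in enumerate(batches):
    # Scale the loss so the accumulated gradient equals the average over the group
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()  # gradients sum into .grad across iterations
    if (idx + 1) % accumulation_steps == 0:
        optimizer.step()       # apply the accumulated gradient once per group
        optimizer.zero_grad()  # reset before the next group of batches
        num_updates += 1
```

Note that in this pattern `zero_grad()` is called only once per group of `accumulation_steps` batches, so the gradients are allowed to accumulate in between.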

Would you please help me solve this problem?