Every time I call the train function, CPU RAM usage keeps increasing. I am training on the GPU, and GPU memory consumption is stable.
Here is the code:
def train_model_memory_optimized(model, train_loader, val_loader, num_epochs=10, lr=0.001,
                                 lambda_reg=0.01, device='cpu', early_stopping_patience=5,
                                 accumulation_steps=4, use_mixed_precision=True):
    """
    Train a classification model with gradient accumulation and optional AMP.

    The model is optimized with Adam and cross-entropy loss. When a validation
    loader is supplied, accuracy and weighted precision/recall/F1 are computed
    after every epoch, the best-accuracy weights are snapshotted, and training
    stops early once accuracy has not improved for ``early_stopping_patience``
    consecutive epochs.

    Args:
        model (nn.Module): The PyTorch model to train. It is moved to ``device``.
        train_loader (iterable): Yields ``(inputs, targets)`` training batches.
        val_loader (iterable or None): Yields ``(inputs, targets)`` validation
            batches. Pass ``None`` to skip validation entirely.
        num_epochs (int, optional): Maximum number of epochs. Defaults to 10.
        lr (float, optional): Adam learning rate. Defaults to 0.001.
        lambda_reg (float, optional): Accepted for interface compatibility.
            NOTE: this value is not applied anywhere in this function.
            Defaults to 0.01.
        device (str or torch.device, optional): Device to train on.
            Defaults to 'cpu'.
        early_stopping_patience (int, optional): Number of non-improving
            epochs tolerated before stopping. Defaults to 5.
        accumulation_steps (int, optional): Number of batches whose gradients
            are accumulated before each optimizer step. Defaults to 4.
        use_mixed_precision (bool, optional): Use autocast + GradScaler; only
            takes effect when ``device`` is a CUDA device. Defaults to True.

    Returns:
        dict or None: An independent copy of the model's ``state_dict`` taken
        at the epoch with the highest validation accuracy, or ``None`` if no
        validation samples were ever evaluated.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")

    # str() lets callers pass either 'cuda:0' or torch.device('cuda:0').
    use_amp = bool(use_mixed_precision) and str(device).startswith('cuda')

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler() if use_amp else None

    def apply_accumulated_gradients():
        """Run one optimizer step on the accumulated grads, then clear them."""
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # -inf (not 0) so the first validated epoch is always checkpointed, even
    # when its accuracy is exactly 0.
    best_val = float("-inf")
    best_model_state_dict = None
    patience_counter = 0

    # Discard gradients left on the parameters by any earlier training call.
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        total_loss = 0.0
        num_batches = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if use_amp:
                with autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                # Divide so the accumulated gradient is the mean over the group.
                loss = loss / accumulation_steps
                scaler.scale(loss).backward()
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets) / accumulation_steps
                loss.backward()

            # .item() yields a Python float, so no graph is kept alive here.
            total_loss += loss.item() * accumulation_steps
            num_batches += 1
            # Drop batch references before the loader fetches the next batch.
            del inputs, outputs, targets, loss

            if num_batches % accumulation_steps == 0:
                apply_accumulated_gradients()

        # Flush a trailing partial accumulation group; otherwise its gradients
        # would be silently carried into the next epoch (or never applied).
        if num_batches % accumulation_steps != 0:
            apply_accumulated_gradients()

        avg_train_loss = total_loss / num_batches if num_batches else float("nan")

        # ---- Validation phase ----
        if val_loader is not None:
            model.eval()
            correct = 0
            total = 0
            # Collect per-batch CPU tensors and concatenate once at the end;
            # concatenating inside the loop re-copies everything every batch.
            target_chunks = []
            prediction_chunks = []
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    # Validation always runs in full precision.
                    outputs = model(inputs)
                    predictions = outputs.argmax(dim=1)
                    correct += (predictions == targets).sum().item()
                    total += targets.size(0)
                    target_chunks.append(targets.cpu())
                    prediction_chunks.append(predictions.cpu())
                    del inputs, outputs, targets, predictions

            # An empty validation loader yields no metrics and no checkpoint.
            if total > 0:
                val_accuracy = correct / total
                all_targets_np = torch.cat(target_chunks).numpy()
                all_predictions_np = torch.cat(prediction_chunks).numpy()
                precision, recall, f1, _ = precision_recall_fscore_support(
                    all_targets_np,
                    all_predictions_np,
                    average='weighted',
                    zero_division=0
                )

                print(f"Epoch {epoch+1}/{num_epochs}")
                print(f"Train Loss: {avg_train_loss:.4f}")
                print(f"Val Accuracy: {val_accuracy:.4f}")
                print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
                print("-" * 50)

                del all_targets_np, all_predictions_np

                if val_accuracy > best_val:
                    # state_dict() returns references to the live parameters,
                    # and dict.copy() is shallow, so clone every tensor;
                    # otherwise later epochs overwrite the "best" weights.
                    live_state = model.state_dict()
                    best_model_state_dict = type(live_state)(
                        (name, tensor.detach().clone())
                        for name, tensor in live_state.items()
                    )
                    best_val = val_accuracy
                    patience_counter = 0
                else:
                    patience_counter += 1

            del target_chunks, prediction_chunks

        # Release cached allocator blocks and collect garbage once per epoch.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping triggered after epoch {epoch+1}")
            break

    return best_model_state_dict