No, the GradScaler will not keep unused references around and thus increase the memory usage.
I would recommend checking all returned tensors, e.g. from check_accuracy, and making sure none of them has a valid .grad_fn, since you are storing these tensors.
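As a minimal illustration of that check (a toy model, not the code from this thread): a tensor produced under gradient tracking still carries a grad_fn, while its detached counterpart does not.

import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(2, 4)).mean()

print(loss.grad_fn is not None)   # True -> storing this tensor keeps the whole graph alive
safe = loss.detach()              # or loss.item() for a plain Python float
print(safe.grad_fn is None)       # True -> safe to store or accumulate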
Thank you for your kind reply. I tried to delete some variables in the check_accuracy routine (see code below), but it does not work; I deleted basically everything. Even with a batch size of 1, a training dataset of only 247 images of 224x224 (very small), and 11 images in the test dataset, the out-of-memory error still occurs in the second epoch (at 66%, after 158 images and 29 seconds). Is there any hint? I am starting to think the model itself is increasing its memory consumption, but I don't see how it could. Is that really possible? I already tested the same code with a 152-layer ResNet on a 20,000-image dataset and it worked fine (the current model is a U-Net with a 50-layer ResNet).
# imports assumed by this snippet
import torch
import torchvision.transforms.functional as tf
from tqdm import tqdm

def check_accuracy(loader, model, loss_fn, device='cuda' if torch.cuda.is_available() else 'cpu'):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    loop = tqdm(loader, desc='Check acc')
    with torch.no_grad():
        for dictionary in loop:
            image, label = dictionary
            x, y = dictionary[image], dictionary[label]
            x, y = x.to(device=device), y.to(device=device)
            y = y.float()
            pred = model(x)
            y = tf.center_crop(y, pred.shape[2:])
            pred = (pred > 0.5).float()
            loss = loss_fn(pred, y)
            num_correct += (pred == y).sum()
            num_pixels += torch.numel(pred)
            smooth = 1e-4
            dice_score += (2*100*(pred*y).sum()+smooth) / ((pred+y).sum()+smooth)
            loop.set_postfix(acc=str(round(100*num_correct.item()/int(num_pixels), 4)))
            # deleting variables
            loss_item = loss.item()
            del loss, pred, x, y, image, label, dictionary

    # deleting variables
    num_correct_item = num_correct.item()
    num_pixels = int(num_pixels)
    dice_score_item = dice_score.item()
    len_loader = len(loader)
    del num_correct, dice_score, loader, loop

    print(f'\nGot an accuracy of {round(100*num_correct_item/int(num_pixels), 4)}')
    print(f'Dice score: {round(dice_score_item/len_loader, 4)}\n')
    model.train()
    return 100*num_correct_item/num_pixels, loss_item, dice_score_item/len_loader
I don't see any obvious issues and would recommend narrowing down the code even further by using only the model with its corresponding training routine. If this still increases the memory usage, could you post a minimal, executable code snippet to reproduce the issue, please?
Thank you very much for your fast reply. It was completely solved by removing an append to a self attribute inside my model. Now it works fine with the whole dataset and with a much larger batch size.
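For readers hitting the same symptom, here is a toy sketch (not the poster's actual model) of how such an append inside a module can grow memory, and how detaching avoids it:

import torch
import torch.nn as nn

class LeakyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)
        self.history = []          # grows with every forward pass

    def forward(self, x):
        out = self.fc(x)
        # appending the attached output keeps each iteration's graph referenced:
        self.history.append(out)
        # storing out.detach().cpu() (or out.item() for scalars) instead
        # lets the graph be freed after backward()
        return out

Storing detached copies (or plain Python numbers) keeps the list from holding the computation graphs of all previous iterations alive.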
Hello, I have the same problem. I am trying to train a transformer model on a GPU cluster, but my GPU memory usage keeps increasing during training until the run fails completely. I tried all of the suggestions above, but they did not seem to work. Here is the train function:
def trainLoop(data, model, loadModel, modelName, lr, weightDecay, earlyStopping, epochs,
              validationSet, validationStep, WandB, device, pathOrigin=pathOrigin):
    """
    data: list of list of input data and dates and targets
    model: pytorch nn.class
    loadModel: boolean
    modelName: string
        .pth.tar model name on harddrive
    lr: float
    weightDecay: float
    earlyStopping: float
    criterionFunction: nn.lossfunction
    epochs: int
    validationSet: same as data
    validationStep: int
        timepoint when validation set is evaluated for early stopping regularization
    WandB: boolean
        use weights and biases tool to monitor losses dynamically
    device: string
        device on which the data should be stored

    return: nn.class
        trained model
    """
    torch.autograd.set_detect_anomaly(True)
    runningLoss = 0
    runningLossLatentSpace = np.zeros(len(data) * epochs)
    meanRunningLossLatentSpace = 0
    runningLossReconstruction = np.zeros(len(data) * epochs)
    meanRunningLossReconstruction = 0
    stoppingCounter = 0
    lastLoss = 0
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weightDecay)
    trainLosses = np.zeros(len(data) * epochs)
    validationLosses = np.zeros((len(data) * epochs, 2))
    trainCounter = 0
    meanValidationLoss = 0

    # WandB
    if WandB:
        wandb.init(
            # set the wandb project where this run will be logged
            project=modelName,
            # track hyperparameters and run metadata
            config={
                "learning_rate": lr,
                "architecture": modelName,
                "dataset": "Helheim, Aletsch, jakobshavn",
                "epochs": epochs,
            }
        )

    # load model
    if loadModel:
        loadCheckpoint(torch.load(modelName), model=model, optimizer=optimizer)

    model.train()
    for x in range(epochs):
        # get indices for epoch
        ix = np.arange(0, len(data), 1)
        ix = np.random.choice(ix, len(data), replace=False, p=None)

        for i in ix:
            # get data
            helper = data[i]

            # move to cuda
            helper = moveToCuda(helper, device)

            # define target
            y = helper[1][0]

            # zero the parameter gradients
            optimizer.zero_grad()
            model.zero_grad()

            # forward + backward + optimize
            forward = model.forward(helper, training=True)
            predictions = forward[0]
            loss = MSEpixelLoss(predictions, y) + forward[1] + forward[2]  # output loss, latent space loss, reconstruction loss
            loss.backward()
            optimizer.step()
            trainCounter += 1

            # print loss
            # meanRunningLossLatentSpace += forward[1].item()
            # meanRunningLossLatentSpace = meanRunningLossLatentSpace/trainCounter
            # runningLossLatentSpace[trainCounter - 1] = meanRunningLossLatentSpace
            # meanRunningLossReconstruction += forward[2].item()
            # meanRunningLossReconstruction = meanRunningLossReconstruction/trainCounter
            # runningLossReconstruction[trainCounter - 1] = meanRunningLossReconstruction
            # runningLoss += loss.item()
            # meanRunningLoss = runningLoss / trainCounter
            # trainLosses[trainCounter - 1] = meanRunningLoss

            ## log to wandb
            if WandB:
                wandb.log({  # "train loss": meanRunningLoss,
                    "train loss": loss.item()})
                    # "latentSpaceLoss": meanRunningLossLatentSpace,
                    # "reconstructionLoss": meanRunningLossReconstruction,
                    # "validationLoss": meanValidationLoss})

            """
            if i % validationStep == 0 and i != 0:
                if validationSet != None:
                    # sample data
                    validationLoss = 0
                    for i in range(len(validationSet)):
                        helper = validationSet[i]

                        # move to cuda
                        helper = moveToCuda(helper, device)
                        y = helper[1][0]

                        # forward + backward + optimize
                        forward = model.forward(helper, training = True)
                        # predictions = forward[0].to(device='cuda')
                        predictions = forward[0]
                        testLoss = MSEpixelLoss(predictions, y) + forward[1] + forward[2]
                        validationLoss += testLoss.item()

                    meanValidationLoss = validationLoss / len(validationSet)
                    validationLosses[trainCounter - 1] = np.array([meanValidationLoss, trainCounter])  # save trainCounter as well for comparison with interpolation
                                                                                                       # of in between datapoints
                    # save memory
                    del forward, helper
                    print("current validation loss: ", meanValidationLoss)

                    # early stopping
                    if earlyStopping > 0:
                        if lastLoss < meanRunningLoss:
                            stoppingCounter += 1

                            if stoppingCounter == 100:
                                print("model converged, early stopping")

                                # navigate/create order structure
                                path = pathOrigin + "/results"
                                os.chdir(path)
                                os.makedirs(modelName, exist_ok=True)
                                os.chdir(path + "/" + modelName)

                                checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
                                saveCheckpoint(checkpoint, modelName)

                                # save losses
                                dict = {"trainLoss": trainLosses, "validationLoss": [np.NaN for x in range(len(trainLosses))]}
                                trainResults = pd.DataFrame(dict)

                                # fill in validation losses with index
                                for i in range(len(validationLosses)):
                                    trainResults.iloc[validationLosses[i, 1], 1] = validationLosses[i, 0]

                                # save dataFrame to csv
                                trainResults.to_csv("resultsTraining.csv")
                                return

                    lastLoss = meanRunningLoss
            """

            # print("epoch: ", x, ", example: ", trainCounter, " current loss = ", meanRunningLoss)
            print("epoch: ", x, ", example: ", trainCounter, " current loss = ", loss.item())

            # save memory
            del loss, forward, helper, y

    path = pathOrigin + "/results"  ## check if takes global variable
    os.chdir(path)
    os.makedirs(modelName, exist_ok=True)
    os.chdir(path + "/" + modelName)

    ## save model anyway in case it did not converge
    checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    saveCheckpoint(checkpoint, modelName)

    """
    # save losses
    dict = {"trainLoss": trainLosses,
            "validationLoss": [np.NaN for x in range(len(trainLosses))],
            "latentSpaceLoss": runningLossLatentSpace,
            "reconstructionLoss": runningLossReconstruction}
    trainResults = pd.DataFrame(dict)

    # fill in validation losses with index
    for i in range(len(validationLosses)):
        trainResults.iloc[int(validationLosses[i, 1].item()), 1] = validationLosses[i, 0].item()

    # save dataFrame to csv
    trainResults.to_csv("resultsTrainingPatches.csv")
    """
    print("results saved!")
    return
I have this issue as well, and I don't know where in the code data keeps being accumulated, but my memory usage increases. Could you please help me with that?
Here is my code:
def train_batch(model, optimizer, device, batch, labels):
    model.train()
    optimizer.zero_grad()
    length = float(batch.size(0))
    mu_x, log_var_x, mu_q, log_var_q, mu_r, log_var_r = model(batch, labels)
    kl_loss_b = KL(mu_r, log_var_r, mu_q, log_var_q)
    L_loss_b = log_lik(labels, mu_x, log_var_x)
    # print("Size of list3: " + str(sys.getsizeof(train_losses)) + "bytes")
    L_loss = torch.sum(L_loss_b)
    kl_loss = torch.sum(kl_loss_b)
    loss = -(L_loss - kl_loss)/length
    loss.backward()
    # update the weights
    optimizer.step()
    # add for validation
    return loss, kl_loss/length, L_loss/length

def trainv(model, device, epochs, train_iterator, optimizer, validate_iterator):
    n_samples = 100
    train_losses, kl_losses, lik_losses, test_losses = [], [], [], []
    for epoch in range(epochs):
        ep_tr, ep_kl, ep_l, num_batch, iterator = 0, 0, 0, 0, 0
        for local_batch, local_labels in train_iterator:
            local_batch, local_labels = local_batch.to(device), local_labels.to(device)
            train_loss, kl_loss, lik_loss = train_batch(model, optimizer, device, local_batch, local_labels)
            ep_tr += train_loss
            ep_kl += kl_loss
            ep_l += lik_loss
            num_batch += 1
            iterator += 1
            del local_batch, local_labels
        train_losses.append(ep_tr/num_batch)
        run_validate_flag = 0
        if run_validate_flag == 1:
            samples, truths, test_loss = runs_for_validate(validate_iterator, n_samples)
            test_losses.append(-test_loss)
        else:
            test_losses = f'run_validate_flag;{0}'
    return train_losses, test_losses
You are accumulating the losses in-place into ep_tr, ep_kl, and ep_l as described here; since these loss tensors are still attached to the computation graph, this will increase the memory usage.
Thank you. I tried to del the loss in the train_batch function, but I can't. I also tried del train_loss, kl_loss, lik_loss in the trainv function, and the memory still increases. Do you know where I should delete the loss or call .item(), and in which function? Thank you.
Deleting the losses after they were already accumulated in-place into ep_* won't help, as PyTorch will not be able to free the tensors anymore since they are now referenced. You would need to detach() the tensors before accumulating them, or call item() on them, as described before.
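For example, a minimal self-contained sketch of the pattern (toy model, not the code posted above); only the accumulation line matters:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ep_tr = 0.0
for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    # ep_tr += loss          # keeps each iteration's graph referenced
    ep_tr += loss.item()     # or loss.detach(); stores only the value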
Is it the same for PyTorch Lightning? After each iteration the allocated CUDA memory increases, and halfway through the epoch I get a memory error, even with 36 GB of memory!
After trying the above (del, item(), etc.), I still could not get rid of the problem.
What worked for me was to NOT DO this (DEVICE is cuda):
#torch.set_default_device(DEVICE) # this causes memory creep issues - DO NOT USE
Ever since I disabled that, memory has remained constant.
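For anyone who wants to verify whether allocated memory really keeps creeping per iteration, a small self-contained sketch (toy model, not taken from this thread) that logs PyTorch's allocator counters each step:

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(128, 128, device="cuda")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for step in range(5):
        optimizer.zero_grad()
        loss = model(torch.randn(64, 128, device="cuda")).pow(2).mean()
        loss.backward()
        optimizer.step()
        # allocated memory should stay roughly flat; steady growth hints at a leak
        print(f"step {step}: "
              f"allocated={torch.cuda.memory_allocated() / 1e6:.1f} MB, "
              f"reserved={torch.cuda.memory_reserved() / 1e6:.1f} MB")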
It has been a while since the last activity in this thread, but I wonder why we need to explicitly call del. In the next iteration, aren't these tensors overwritten by the new ones, so the garbage collector can collect the now-unreferenced tensors?