I have a problem: memory usage almost doubles during training.
Here is my code. Could you suggest a fix for the growing memory usage?
The call order is ssl_train -> _train_epoches -> _train_batch, and memory seems to increase every time _train_epoches finishes. Thank you. @ptrblck_de, could you help me?
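For reference, this is roughly how I am watching the growth (a minimal sketch; the helper name and the print placement are mine, not in the code below, and it assumes training runs on a CUDA device):

import torch

def log_gpu_memory(tag):
    # allocated = memory held by live tensors; reserved = what the caching allocator keeps
    print('%s: allocated=%.1fMB reserved=%.1fMB' % (
        tag,
        torch.cuda.memory_allocated() / 1024**2,
        torch.cuda.memory_reserved() / 1024**2))

I call log_gpu_memory('after epoch %d' % epoch) right after each _train_epoches() call inside ssl_train, and the allocated number keeps climbing.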
def _train_batch(self, input_variable, input_length, target_variable, model):
    logits, attn = model(input_variable, input_length)
    loss = self.criterion(logits, target_variable)
    self.optimizer.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss
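One thing I considered (I am not sure it is the cause, since I already call .item() and del loss in the caller): returning the loss tensor keeps a reference to the autograd graph alive until the caller drops it. A variant that returns a plain Python float instead would look like this:

def _train_batch(self, input_variable, input_length, target_variable, model):
    logits, attn = model(input_variable, input_length)
    loss = self.criterion(logits, target_variable)
    self.optimizer.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    # loss.item() returns a float, so no reference to the graph
    # survives once this frame exits
    return loss.item()

The caller would then accumulate loss_total += loss directly instead of loss.item().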
def _train_epoches(self, data, model, n_epochs, dev_data=None, test_data=None):
    labeled_dataset = torchtext.data.Dataset(data, fields=[('text', self.TEXT_field), ('label', self.LABEL_field)])
    label_batch_iter = torchtext.data.BucketIterator(dataset=labeled_dataset, batch_size=128,
                                                     sort_key=lambda x: len(x.text), sort_within_batch=True,
                                                     device=self.device, repeat=False, shuffle=True)
    log = self.logger
    early_stopping = EarlyStopping(patience=3, verbose=True)
    best_accuracy = 0
    for epoch in range(n_epochs):
        model.train()
        loss_total = 0
        step = 0
        for batch in label_batch_iter:
            input_variables, input_lengths = batch.text
            target_variables = batch.label
            loss = self._train_batch(input_variables, input_lengths.tolist(), target_variables, model)
            loss_total += loss.item()
            step += 1
            del loss, batch
        epoch_loss_avg = loss_total / step
        log_msg = "Finished epoch %d: SSL Train %s: %.4f" % (epoch, 'Cross_Entropy', epoch_loss_avg)
        with torch.no_grad():
            if dev_data is not None:
                model.eval()
                dev_loss, dev_acc = self.evaluator.evaluate(model, dev_data)
                self.dev_acc = dev_acc
                log_msg += ", Dev %s: %.4f, Accuracy: %.4f" % ('Cross_Entropy', dev_loss, dev_acc)
                log.info(log_msg)
                early_stopping(dev_loss, model, self.optimizer, epoch, step, self.input_vocab, self.expt_dir)
                print('dev_ early stopping', early_stopping.counter)
                if self.dev_acc > best_accuracy:
                    best_accuracy = self.dev_acc
                    Checkpoint(model=model, optimizer=self.optimizer,
                               epoch=epoch, step=step, input_vocab=self.input_vocab).save(self.expt_dir + '/best_accuracy')
                    print('*' * 100)
                    print('SAVE MODEL (BEST DEV ACC)')
            if test_data is not None:
                model.eval()
                test_loss, accuracy = self.evaluator.evaluate(model, test_data)
                log_msg += ", Test %s: %.4f, Accuracy: %.4f" % ('Cross_Entropy', test_loss, accuracy)
                log.info(log_msg)
        if early_stopping.early_stop:
            print("Early Stopping")
            # roll back to the best checkpoint and rebuild the optimizer
            # so it points at the restored model's parameters
            checkpoint = Checkpoint.get_latest_checkpoint(self.expt_dir + '/best_accuracy')
            checkpoint = Checkpoint.load(checkpoint)
            model = checkpoint.model
            optimizer = checkpoint.optimizer
            resume_optim = checkpoint.optimizer.optimizer
            defaults = dict(resume_optim.param_groups[0])  # copy so param_groups is not mutated
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)
            self.optimizer = optimizer
            loss, accuracy = self.evaluator.evaluate(model, test_data)
            print('LOAD BEST ACCURACY MODEL ::: loss: {}, accuracy: {}'.format(loss, accuracy))
            break
    return model
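Since _train_epoches builds a fresh BucketIterator and saves Checkpoint objects on every call, one thing I can add between calls is explicit cleanup (hypothetical helper, not in my code yet):

import gc
import torch

def release_cuda_cache():
    # drop unreachable Python objects first so their CUDA storage becomes free
    gc.collect()
    # hand cached blocks back to the driver; memory still referenced by
    # live tensors is untouched, so a genuine leak survives this call
    torch.cuda.empty_cache()

As far as I understand, empty_cache() only changes what nvidia-smi reports and cannot free memory that live tensors still reference, so it should not mask a real leak.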
def ssl_train(self, num_epochs=30, dev_data=None, test_data=None, methods=None, reverse_augment=False):
    log = self.logger
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = self.model
    for epoch in range(num_epochs):
        print('INFO. # current epoch:', epoch)
        # rebuild the iterator over the updated unlabeled pool each SSL epoch
        unlabeled_dataset = torchtext.data.Dataset(self.unlabeled_examples,
                                                   fields=[('text', self.TEXT_field), ('label', self.LABEL_field)])
        unlabel_batch_iter = torchtext.data.BucketIterator(dataset=unlabeled_dataset, batch_size=128,
                                                           sort_key=lambda x: len(x.text), sort_within_batch=True,
                                                           device=self.device, repeat=False, shuffle=False, train=False)
        # predict labels for the unlabeled dataset
        labeled_samples = []
        model.eval()
        with torch.no_grad():
            for unlabel_batch in unlabel_batch_iter:
                str_list = itos(self.input_vocab, unlabel_batch.text[0])
                str_list = [sent.replace('<pad>', '').strip() for sent in str_list]
                input_var, input_len = unlabel_batch.text
                target_label = unlabel_batch.label
                logits, attn = model(input_var, input_len.tolist())
                if methods == 'self-training':
                    labeled_dataset = self.self_training_labeling(input_var, str_list, logits, attn, target_label)
                elif methods == 'pseudo-label':
                    labeled_dataset = self.pseudo_labeling(input_var, str_list, logits, attn, target_label)
                else:
                    labeled_dataset = []  # unknown method: nothing to add
                labeled_samples.extend(labeled_dataset)
        labeled_samples = list(set(labeled_samples))
        if reverse_augment is False:
            labeled_samples = self.balancing_labeled_result(labeled_samples)
            print('INFO. # num of new balanced examples from pseudo-label result:', len(labeled_samples))
        else:
            self.save_labeled_result(self.outputs_dir + '/pseudo_label_epoch_{}.txt'.format(epoch), labeled_samples)
            reversed_examples = self.gen_reverse_examples(labeled_samples)
        self.update_dataset(labeled_samples)
        if reverse_augment is True:
            self.labeled_examples.extend(reversed_examples)
            print('ADD REVERSED_EXAMPLES {} -> {}'.format(len(reversed_examples), len(self.labeled_examples)))
        model = self._train_epoches(self.labeled_examples, model, 30, dev_data, test_data)
        if self.ssl_early_stopping_patience == self.ssl_early_stopping or len(self.unlabeled_examples) == 0:
            print('SSL EARLY STOPPING EPOCH {}'.format(epoch))
            break
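One more detail that may be relevant: as far as I can tell, update_dataset() and labeled_examples.extend() grow the training set on every SSL epoch, so some growth is expected; what surprises me is the near-doubling. To separate expected growth from a leak, I would log the dataset size next to the allocator stats, e.g. (hypothetical logging, not in the code above):

print('INFO. # labeled={} unlabeled={} allocated={:.1f}MB'.format(
    len(self.labeled_examples),
    len(self.unlabeled_examples),
    torch.cuda.memory_allocated() / 1024**2))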