Hi, I am trying to use the autograd.grad
function with create_graph. I found that when I set create_graph=True,
the memory usage grows with every iteration.
Here is my code below — am I doing something wrong?
Thanks!
def _weighted_loss(loss_dict):
    # Weighted sum of the per-head detection losses; shared by the
    # meta-train and meta-test passes so the weights stay in one place.
    w = self.args.train
    return (w.w_RPN_loss_cls * loss_dict['loss_objectness']
            + w.w_RPN_loss_box * loss_dict['loss_rpn_box_reg']
            + w.w_RCNN_loss_bbox * loss_dict['loss_box_reg']
            + w.w_RCNN_loss_cls * loss_dict['loss_detection']
            + w.w_OIM_loss_oim * loss_dict['loss_reid']
            + loss_dict['loss_occ']
            + loss_dict['loss_occ_bg'])

for iteration, (data1, data2, data3, data4) in enumerate(
        zip(self.train_loader1, self.train_loader2,
            self.train_loader3, self.train_loader4)):
    ## Initial iterations
    steps = epoch * len(self.train_loader1) + iteration
    if steps % self.args.train.disp_interval == 0:
        start = time.time()
        torch.cuda.empty_cache()

    self.model_s.zero_grad()
    self.model.zero_grad()

    # Shuffle the four batches, use two for meta-train and one for meta-test.
    datalst = [data1, data2, data3, data4]
    random.shuffle(datalst)
    data = [*datalst[0], *datalst[1]]
    data_mte = [*datalst[2]]

    # Re-sync the simulated model with the real model's current weights.
    self.model_s.load_state_dict(self.model.state_dict())

    # Load data
    images, targets = ship_data_to_cuda(data, self.device)

    ####! Simulate
    #! meta train
    loss_dict_s = self.model_s(images, targets)
    losses_s = _weighted_loss(loss_dict_s)

    # model_s update: create_graph=True is required so the meta-test loss
    # can differentiate through this inner SGD step.
    grad_mtr = torch.autograd.grad(
        losses_s,
        [p for p in self.model_s.params() if p.requires_grad],
        create_graph=True)
    self.model_s.update_params(lr_inner=self.optimizer.param_groups[0]['lr'],
                               source_params=grad_mtr, solver='sgd')

    #! meta test
    self.model_s.zero_grad()
    images_mte, targets_mte = ship_data_to_cuda(data_mte, self.device)
    loss_dict_mte = self.model_s(images_mte, targets_mte)
    losses_mte = _weighted_loss(loss_dict_mte)

    loss_final = 0.8 * losses_s + 0.2 * losses_mte
    # No create_graph here: the graph (including the one built above) is
    # freed by this backward pass.
    grad = torch.autograd.grad(
        loss_final,
        [p for p in self.model_s.params() if p.requires_grad])
    self.model.update_params(lr_inner=self.optimizer.param_groups[0]['lr'],
                             source_params=grad, solver='sgd')
    # Drop the last references to the higher-order grads so the graph can
    # be collected before the next iteration.
    # NOTE(review): if update_params stores source_params internally, that
    # copy must also be released/detached there — verify its implementation.
    del grad_mtr, grad

    ## Post iterations
    if epoch == 0 and self.args.train.lr_warm_up:
        sub_scheduler.step()

    if steps % self.args.train.disp_interval == 0:
        # Print
        loss_value = losses_s.item()
        state = dict(loss_value=loss_value,
                     lr=self.optimizer.param_groups[0]['lr'])
        # FIX (memory leak): the original did state.update(loss_dict_s),
        # storing graph-attached tensors (built with create_graph=True) in
        # metric_logger. The logger kept them alive across iterations, so
        # every iteration's autograd graph accumulated — this is the memory
        # growth you observed. Log plain Python floats instead.
        state.update({k: v.item() for k, v in loss_dict_s.items()})
        # Update logger
        batch_time = time.time() - start
        metric_logger.update(batch_time=batch_time)
        metric_logger.update(**state)
        # Print log on console
        metric_logger.print_log(epoch, iteration, len(self.train_loader1))
    else:
        state = None