Memory leak caused by autograd.grad(create_graph=True)

Hi, I am trying to use the autograd.grad function with create_graph. However, when I set create_graph=True, the GPU memory usage grows every iteration.
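
To make the call pattern concrete, here is a stripped-down sketch (a hypothetical toy model, not my actual detector; it only mirrors the two autograd.grad calls used in the loop below):

import torch
import torch.nn.functional as F

# Hypothetical toy setup, just to show the call pattern.
model = torch.nn.Linear(10, 1)
x = torch.randn(64, 10)
y = torch.randn(64, 1)

for _ in range(100):
    loss = F.mse_loss(model(x), y)

    # First grad call keeps the graph so the gradients stay differentiable.
    grads = torch.autograd.grad(
        loss,
        [p for p in model.parameters() if p.requires_grad],
        create_graph=True,
    )

    # Second grad call without create_graph, analogous to the meta-test step.
    loss2 = loss + sum(g.pow(2).sum() for g in grads)
    grads2 = torch.autograd.grad(
        loss2,
        [p for p in model.parameters() if p.requires_grad],
    )

    del grads, grads2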

My actual training loop is below. Am I doing something wrong?

Thanks!

for iteration, (data1, data2, data3, data4) in enumerate(zip(self.train_loader1, self.train_loader2, self.train_loader3, self.train_loader4)):
                ## Initial iterations
                steps = epoch*len(self.train_loader1) + iteration
                if steps % self.args.train.disp_interval == 0:
                    start = time.time()
                torch.cuda.empty_cache()
                
                self.model_s.zero_grad()
                self.model.zero_grad()

                datalst = [data1, data2, data3, data4]
                random.shuffle(datalst)

                data = [*datalst[0], *datalst[1]]
                data_mte = [*datalst[2]]

                sd = self.model.state_dict()
                self.model_s.load_state_dict(sd)

                # Load data
                images, targets = ship_data_to_cuda(data, self.device)

                ####! Simulate
                #! meta train 
                # Pass data to model
                loss_dict_s = self.model_s(images, targets)

                # simul
                losses_s = self.args.train.w_RPN_loss_cls * loss_dict_s['loss_objectness'] \
                    + self.args.train.w_RPN_loss_box * loss_dict_s['loss_rpn_box_reg'] \
                    + self.args.train.w_RCNN_loss_bbox * loss_dict_s['loss_box_reg'] \
                    + self.args.train.w_RCNN_loss_cls * loss_dict_s['loss_detection'] \
                    + self.args.train.w_OIM_loss_oim * loss_dict_s['loss_reid'] \
                    + loss_dict_s['loss_occ'] \
                    + loss_dict_s['loss_occ_bg']

                # model_s update
                grad_mtr = torch.autograd.grad(
                    losses_s,
                    [p for p in self.model_s.params() if p.requires_grad],
                    create_graph=True,
                )
                self.model_s.update_params(lr_inner=self.optimizer.param_groups[0]['lr'],
                                           source_params=grad_mtr, solver='sgd')
                
                #! meta test
                self.model_s.zero_grad()
                images_mte, targets_mte = ship_data_to_cuda(data_mte, self.device)

                loss_dict_mte = self.model_s(images_mte, targets_mte)
                losses_mte = self.args.train.w_RPN_loss_cls * loss_dict_mte['loss_objectness'] \
                    + self.args.train.w_RPN_loss_box * loss_dict_mte['loss_rpn_box_reg'] \
                    + self.args.train.w_RCNN_loss_bbox * loss_dict_mte['loss_box_reg'] \
                    + self.args.train.w_RCNN_loss_cls * loss_dict_mte['loss_detection'] \
                    + self.args.train.w_OIM_loss_oim * loss_dict_mte['loss_reid']  \
                    + loss_dict_mte['loss_occ'] \
                    + loss_dict_mte['loss_occ_bg']

                loss_final = 0.8 * losses_s + 0.2 * losses_mte

                grad = torch.autograd.grad(
                    loss_final,
                    [p for p in self.model_s.params() if p.requires_grad],
                )
                self.model.update_params(lr_inner=self.optimizer.param_groups[0]['lr'],
                                         source_params=grad, solver='sgd')

                del grad_mtr, grad
                
                ## Post iterations
                if epoch == 0 and self.args.train.lr_warm_up:
                    sub_scheduler.step()

                if steps % self.args.train.disp_interval == 0:
                    # Print 
                    loss_value = losses_s.item()
                    state = dict(loss_value=loss_value,
                                lr=self.optimizer.param_groups[0]['lr'])
                    state.update(loss_dict_s)

                    # Update logger
                    batch_time = time.time() - start
                    metric_logger.update(batch_time=batch_time)
                    metric_logger.update(**state)
                        
                    # Print log on console
                    metric_logger.print_log(epoch, iteration, len(self.train_loader1))
                else:
                    state = None
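
In case it is relevant, this is roughly how the growth can be observed inside the loop (a hypothetical logging fragment using torch.cuda.memory_allocated / torch.cuda.memory_reserved; it reuses steps, self.args, and self.device from the code above and is not part of my training logic):

# Hypothetical check: print allocated/reserved CUDA memory every disp_interval
# steps and watch the numbers increase across iterations.
if steps % self.args.train.disp_interval == 0:
    alloc_mb = torch.cuda.memory_allocated(self.device) / 1024 ** 2
    reserved_mb = torch.cuda.memory_reserved(self.device) / 1024 ** 2
    print(f"step {steps}: allocated={alloc_mb:.1f} MiB, reserved={reserved_mb:.1f} MiB")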