CUDA out of memory error if I don't call .backward()

I am getting a CUDA out of memory error when running the following:

    def train(self, epochs, saving_freq):
        assert self.model.training  # make sure the model is in training mode
        while self.epoch <= epochs:
            running_loss = {}

            for idx, batch_sample in enumerate(self.train_loader):
                print("Iteration:", idx)
                loss_dict = self.train_step(batch_sample)

                if idx%10==0:
                    with torch.no_grad():
                        #----------- Logging and Printing ----------#
                        print("{:<8d} {:<9d} {:<7.4f}".format(self.epoch, idx, loss_dict['tot_loss'].item()))
                        for loss_name, value in loss_dict.items():
                            self.tb_writer.add_scalar('Loss/'+loss_name, value.item())

where train_step is defined as:

    def train_step(self, batch_sample):
        in_images = batch_sample['image'].to(self.device)
        target = [x.to(self.device) for x in batch_sample['target']]

        rpn_proposals, instances, rpn_losses, detection_losses = self.model(in_images, target, self.is_training)
        print([len(x) for x in instances])
        
        loss_dict = {}
        loss_dict.update(rpn_losses)
        loss_dict.update(detection_losses)
        
        loss = 0.0
        for k, v in loss_dict.items():
            loss += v
        loss_dict.update({'tot_loss':loss})

        self.optimizer.zero_grad()
        # loss.backward()
        self.optimizer.step()

        return loss_dict
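
For what it's worth, my understanding is that every tensor in loss_dict is still attached to the autograd graph when train_step returns, since I never call backward() or detach() on it. Below is a stripped-down toy of the same pattern (a made-up nn.Linear model, not my actual network, and I have not checked whether this minimal version reproduces the error); it shows why I would expect each iteration's graph to be freed on its own once loss is reassigned:

    import torch
    import torch.nn as nn

    # Toy version of the pattern in train_step (made-up model, not my real one).
    # Nothing outside the loop keeps a reference to `loss`, so I would expect
    # the previous iteration's graph to be freed as soon as `loss` is reassigned.
    model = nn.Linear(512, 512).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for idx in range(1000):
        x = torch.randn(64, 512, device='cuda')
        loss = model(x).pow(2).mean()   # builds a fresh autograd graph each time
        optimizer.zero_grad()
        # loss.backward()               # commented out, as in my train_step
        optimizer.step()
        if idx % 10 == 0:
            print(idx, loss.item())     # .item() returns a plain Python float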

The weird thing is that the memory error goes away if I do any of the following three things:

  1. Uncomment loss.backward(). I understand that calling backward() frees the computation graph, but even in my case the previous iteration's graph should be freed on its own, since nothing holds a reference to it once loss_dict is overwritten.
  2. Change if idx%10==0: to if idx%1==0: in train(), i.e. do the logging every iteration instead of every 10 iterations.
  3. Comment out this logging line:

        for loss_name, value in loss_dict.items():
            self.tb_writer.add_scalar('Loss/'+loss_name, value.item())

So I guess that line is somehow preventing the memory from being freed, even though it only logs scalars to TensorBoard.
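
If that guess is right, I suppose a workaround would be to make sure nothing attached to the autograd graph leaves train_step, e.g. by detaching the losses before returning them. This is an untested sketch of what I mean, reusing the names from my code above (only the end of train_step changes):

        # Untested idea: detach every loss tensor before it leaves train_step, so
        # no reference to the autograd graph escapes into the logging code in train().
        loss_dict.update({'tot_loss': loss})

        self.optimizer.zero_grad()
        # loss.backward()
        self.optimizer.step()

        return {k: v.detach() for k, v in loss_dict.items()}

The .item() calls in train() would still work on the detached tensors. What confuses me is that add_scalar already receives value.item(), i.e. a plain Python float, so I don't see how the logging line could keep the graph alive in the first place.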

Let me know if you need an error log.