How to reduce memory usage?

Hi everyone, I calculate the loss from multiple stages of an hourglass net, and I found that the memory usage of calculating the loss is very large, so I cannot train the model with a batch size greater than 4. The output of the model is a list of tensors, one from each stage, so I compute the loss against the ground-truth heatmaps for each stage and sum them up. How can I reduce memory usage by modifying my code? I mean, can I calculate the loss in a memory-efficient way? You can see my code below. Thank you. Note that the size of dense_map and weight is 34x96x96, and center_map is 1x96x96.

def _train_epoch(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        For every batch, computes the root (center) and dense losses for
        each hourglass stage, sums them into a single total loss,
        backpropagates once, and steps the optimizer.

        Args:
            epoch: Current epoch index (used only for the progress-bar label).

        Returns:
            Tuple ``(root_loss_avg, dense_loss_avg, total_loss_avg)`` —
            per-batch averages accumulated over all stages of the epoch.
        """
        self.model.module.train()
        train_loader_bar = tqdm(range(len(self.train_loader)),
                                desc="Epoch {}".format(epoch), position=0
                                )
        self.train_loss.reset()
        root_loss_sum = 0.0
        dense_loss_sum = 0.0
        for img, center_map, dense_map, weight in self.train_loader:
            img = img.to(self.device)
            center_map = center_map.to(self.device)
            dense_map = dense_map.to(self.device)
            weight = weight.to(self.device)
            center_pred, dense_pred = self.model(img)

            # The weighted ground truth is identical for every stage, so
            # build it once instead of materialising a fresh 34x96x96
            # temporary per stage (the original recomputed it each pass).
            dense_target = dense_map * weight

            # One uniform loop over all stages replaces the duplicated
            # stage-0 / stage-1..N code paths of the original.
            total_loss = None
            for center_s, dense_s in zip(center_pred, dense_pred):
                root_loss = self.root_criterion(center_s, center_map)
                dense_loss = self.dense_criterion(dense_s * weight,
                                                  dense_target)
                stage_loss = root_loss + self.dense_loss_weight * dense_loss
                total_loss = stage_loss if total_loss is None \
                    else total_loss + stage_loss
                # .item() detaches to a Python float, so the running sums
                # hold no graph references.
                root_loss_sum += root_loss.item()
                dense_loss_sum += dense_loss.item()

            self.train_loss.update(total_loss.item())
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            train_loader_bar.update(1)
        # Close the bar so later console output does not overwrite it.
        train_loader_bar.close()
        root_loss_avg = root_loss_sum / len(self.train_loader)
        dense_loss_avg = dense_loss_sum / len(self.train_loader)
        return root_loss_avg, dense_loss_avg, self.train_loss.avg