Training with nn.DistributedDataParallel, GPU 0 occupies much more memory

I train with nn.DistributedDataParallel (the code below uses apex's DistributedDataParallel wrapper together with amp).

self.load_state(conf, fixed_str="accuracy0.9960_step160000.pth", merge_bn=False)
When I load the pretrained model, the first GPU ends up occupied with much more memory. Part of the nvidia-smi output (the bolded line shows the second process also allocating memory on the first GPU):

|    2      8538      C   /usr/bin/python3                            8025MiB |
**|    2      8539      C   /usr/bin/python3                            1369MiB |**
|    3      8539      C   /usr/bin/python3                            8037MiB |
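
For reference, here is a small diagnostic helper (hypothetical, not part of my training script) that I can call from every rank right after load_state to see which process is holding the extra allocation on the first GPU:

import torch
import torch.distributed as dist

def report_gpu_memory(tag=""):
    # print how much CUDA memory this rank has allocated on each visible device
    rank = dist.get_rank() if dist.is_initialized() else 0
    for d in range(torch.cuda.device_count()):
        allocated_mb = torch.cuda.memory_allocated(d) / 1024 ** 2
        print("[rank {}] {} cuda:{} allocated: {:.0f} MiB".format(rank, tag, d, allocated_mb))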

This is my code:

from data.data_pipe import de_preprocess, get_train_loader, get_val_data
from data.dataset import data_prefetcher
# from model import Backbone, Arcface, MobileFaceNet, Am_softmax, l2_norm
from face_loss import Arcface, SoftmaxFace, SiamLoss, CenterLoss, AdaCos
# import resnet_varg as varg
import resnet_daisy as varg
from verifacation import evaluate
import torch
from torch import optim
import numpy as np
from tqdm import tqdm
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
plt.switch_backend('agg')
from utils import get_time, gen_plot, separate_bn_paras  #hflip_batch
from PIL import Image
from torchvision import transforms as trans
import math
from pathlib import Path
import time
import sys
import mxnet as mx
import pickle
import cv2
from util_wqaq import *
import torch.distributed as dist
from apex import amp
from apex.parallel import DistributedDataParallel
import torch.backends.cudnn as cudnn

class face_learner(object):
    def __init__(self, conf, inference=False):
        print(conf)

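        # one process per GPU; conf.local_rank is expected to be set by the launcher (e.g. torch.distributed.launch)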
        dist.init_process_group(backend='nccl')
        # torch.cuda.synchronize()
        torch.cuda.set_device(conf.local_rank)
        # cudnn.benchmark = True
        # cudnn.deterministic = True

        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            # self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).to(conf.device)
            self.model = varg.Res_VarG().cuda()
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))

        if not inference:
            self.milestones = conf.milestones

            self.loader, self.class_num, self.sample_num = get_train_loader(conf)

            # self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            # self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num).to(conf.device)
            self.head = AdaCos(embedding_size=conf.embedding_size, classnum=self.class_num).cuda()

            print('model head generated')

            paras_only_bn, paras_wo_bn, paras_bias = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD([
                                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                                    {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
                                    {'params': paras_only_bn}
                                ], lr = conf.lr, momentum = conf.momentum)
            else:
                self.optimizer = optim.SGD([
                                    {'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 1e-5},
                                    {'params': paras_bias, 'lr' : 0},
                                    # {'params': [self.center.module.centers], 'lr': 0},
                                    {'params': paras_only_bn}
                                ], lr = conf.lr, momentum = conf.momentum)

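            # apex expects amp.initialize() to run before the model/head are wrapped with DistributedDataParallel below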
            self.model_list, self.optimizer = amp.initialize([self.model] + [self.head], self.optimizer, opt_level="O1")

            self.model = DistributedDataParallel(self.model_list[0])
            self.head = DistributedDataParallel(self.model_list[1])

            self.load_state(conf, fixed_str="accuracy0.9960_step160000.pth", merge_bn=False)

            print(self.optimizer)

            print('optimizers generated')
            # self.board_loss_every = len(self.loader)//100
            self.board_loss_every = 100
            # self.evaluate_every = len(self.loader)//3
            self.evaluate_every = 10000
            # self.evaluate_every = 1
            # self.save_every = len(self.loader)
            self.save_every = 10000
            # self.save_every = 1000
            # self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(Path(self.loader.dataset.root).parent)
        else:
            self.threshold = conf.threshold
    

    def load_state(self, conf, fixed_str, from_save_folder=False, model_only=False, merge_bn=False):
        if from_save_folder:
            save_path = str(conf.save_path)
        else:
            save_path = str(conf.model_path)

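        # map_location='cpu' keeps torch.load() from deserializing the checkpoint tensors
        # onto the GPU they were saved from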
        self.model.module.load_state_dict(torch.load(save_path + '/model_{}'.format(fixed_str), map_location=torch.device('cpu')))

        if not model_only:
            self.head.module.load_state_dict(torch.load(save_path + '/head_{}'.format(fixed_str), map_location=torch.device('cpu')))
            # there is no self.head2 in this setup, so this line is commented out
            # self.head2.module.load_state_dict(torch.load(save_path + '/head2_{}'.format(fixed_str), map_location=torch.device('cpu')))
            # self.optimizer.load_state_dict(torch.load(save_path + '/optimizer_{}'.format(fixed_str), map_location=torch.device('cpu')))
        print("load {} {}".format(save_path, fixed_str))


    def train(self, conf, epochs):
        self.model.train()
        # self.model.eval()
        running_loss = 0.
        running_loss2 = 0.
        running_loss3 = 0.
        best_acc = 0.99  # evaluate() seems to return accuracy in [0, 1] (cf. the checkpoint name), so 99.0 could never be beaten
        total_rec_right = 0
        total_rec_right2 = 0
        accumulation_steps = 1
        icont2 = 0

        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()

            iters = self.sample_num // (conf.batch_size * conf.num_gpus)
            data_iter = iter(self.loader)

            for i in range(iters):
                imgs1, labels1 = next(data_iter)

                imgs1 = imgs1.cuda(non_blocking=True)
                labels1 = labels1.cuda(non_blocking=True)

                self.optimizer.zero_grad()
                embeddings = self.model(imgs1)
 
                thetas = self.head(embeddings, labels1, 1)
                loss1 = conf.ce_loss(thetas, labels1)

                loss = loss1
                # with apex amp, backward() should only be called on the scaled loss
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                running_loss += loss1.item()

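                # with accumulation_steps = 1 this steps the optimizer every iteration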
                if i % accumulation_steps == 0:
                    self.optimizer.step()
                    # self.optimizer.zero_grad()

                rec_right = (torch.argmax(thetas, dim=1) == labels1).cpu().numpy().sum()

                total_rec_right += rec_right
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    training_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                    loss_board = running_loss / self.board_loss_every

                    precision = total_rec_right / (self.board_loss_every * conf.batch_size)

                    print(""" {}: epoch {} step {}  loss_board:{:.4f}  loss_board2:{:.4f} acc:{:.4f} acc2:{:.4f}""".format(training_time,
                         e, self.step, loss_board, loss_board2, precision, precision2))
                    running_loss = 0.
                    total_rec_right = 0
                    # sys.stdout.flush()

                accuracy = 0.0
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold = self.evaluate(conf, conf.data_path + '/agedb_30.bin')
                    self.board_val('agedb_30', accuracy, best_threshold)
                    accuracy, best_threshold = self.evaluate(conf, conf.data_path + '/cfp_fp.bin')
                    self.board_val('cfp_fp', accuracy, best_threshold)
                    accuracy, best_threshold = self.evaluate(conf, conf.data_path + '/lfw.bin')
                    self.board_val('lfw', accuracy, best_threshold)
                    self.model.train()

                if (self.step % self.save_every == 0 and self.step != 0) or (accuracy > best_acc):
                    self.save_state(conf, accuracy)
                    if accuracy > best_acc:
                        best_acc = accuracy
                        print("lfw best acc {}".format(accuracy))

                self.step += 1

        accuracy, best_threshold = self.evaluate(conf, conf.data_path + '/agedb_30.bin')
        self.board_val('agedb_30', accuracy, best_threshold)
        accuracy, best_threshold = self.evaluate(conf, conf.data_path + '/cfp_fp.bin')
        self.board_val('cfp_fp', accuracy, best_threshold)
        accuracy, best_threshold = self.evaluate(conf, conf.data_path + '/lfw.bin')
        self.board_val('lfw', accuracy, best_threshold)

        self.save_state(conf, accuracy, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)