cuda out of memory

i have written this code and as the training process goes on, the GPU memory usage just becoming larger and larger, until out of memory.I’ve located the problem in the function train(),when i use the same batch in all epochs, there won’t be any problem,but if i shuffle the data and create new batches with the same data, the out of memory error happens.I’ve try torch.cuda.empty_cache() but it doesn’t help.can somebody helps me?
dataloder

import os
import jieba
import pickle
import random
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
padToken, goToken, eosToken, unknownToken = 0, 1, 2, 3


class Batch:
    # batch类,里面包含了encoder输入,decoder输入以及他们的长度
    def __init__(self):
        self.encoder_inputs = None
        self.encoder_inputs_length = None
        self.decoder_targets = None
        self.decoder_targets_length = None
        self.mask_t = None
        self.mask_s = None


def loadDataset(filename):
    """
    :param filename: 数据的路径,数据是一个json结构,包含三部分,分别是word2id,即word到id的转换,
    id2word,即id到word的转换 ,以及训练数据trainingSamples,是一个二维数组,形状为N*2,每一行包含问题和回答
    :return: 通过pickle解析我们的数据,返回上述的三部分内容。
    """
    dataset_path = os.path.join(filename)
    print('Loading dataset from {}'.format(dataset_path))
    with open(dataset_path, 'rb') as handle:
        data = pickle.load(handle)
        word2id = data['word2id']
        id2word = data['id2word']
        train_samples = data['train_samples']
        val_samples = data['val_samples']
        test_samples = data['test_samples']
        pretrain_embedding = data['pretrain_embedding']
    return word2id, id2word, pretrain_embedding, train_samples, val_samples, test_samples

def by_score(t):
    return len(t[0])

def createBatch(samples):
    '''
    根据给出的samples(就是一个batch的数据),进行padding并构造成placeholder所需要的数据形式
    :param samples: 一个batch的样本数据,列表,每个元素都是[question, answer]的形式,id
    :return: 处理完之后可以直接传入feed_dict的数据格式
    '''
    batch = Batch()
    _samples = sorted(samples, key=by_score, reverse=True)
    encoder_inputs_length = [len(sample[0]) for sample in _samples]
    decoder_targets_length = [len(sample[1]) + 1 for sample in _samples]
    batch.encoder_inputs_length = torch.LongTensor(encoder_inputs_length)
    # 模型训练时要加eos,所以这里targrt——length要加1
    batch.decoder_targets_length = torch.LongTensor(decoder_targets_length)

    max_source_length = max(encoder_inputs_length)
    max_target_length = max(decoder_targets_length)
    encoder_inputs = []
    decoder_targets = []
    mask_t = []
    mask_s = []
    for  index in range(len(_samples)):
        # 将source进行反序并PAD值本batch的最大长度
        # source = list(reversed(sample[0]))
        pad = [padToken] * (max_source_length - len(_samples[index][0]))
        encoder_inputs.append(_samples[index][0] + pad)
        mask_s.append([1] * encoder_inputs_length[index] + [0] * (max_source_length - encoder_inputs_length[index]))

        # 将target进行PAD,并添加END符号
        target = _samples[index][1]
        target.append(2)
        pad = [padToken] * (max_target_length - decoder_targets_length[index])
        decoder_targets.append(target + pad)
        mask_t.append([1] * decoder_targets_length[index] + [0] * (max_target_length - decoder_targets_length[index]))
        # batch.target_inputs.append([goToken] + target + pad[:-1])
    batch.encoder_inputs = torch.LongTensor(encoder_inputs)
    batch.decoder_targets = torch.LongTensor(decoder_targets)
    batch.mask_t = torch.ByteTensor(mask_t)
    batch.mask_s = torch.ByteTensor(mask_s)

    return batch


def getBatches(data, batch_size):
    '''
    根据读取出来的所有数据和batch_size将原始数据分成不同的小batch。对每个batch索引的样本调用createBatch函数进行处理
    :param data: loadDataset函数读取之后的trainingSamples,就是QA对的列表
    :param batch_size: batch大小
    :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度
    :return: 列表,每个元素都是一个batch的样本数据,可直接传入feed_dict进行训练
    '''
    # 每个epoch之前都要进行样本的shuffle
    random.shuffle(data)
    batches = []
    data_len = len(data)

    def genNextSamples():
        for i in range(0, data_len, batch_size):
            yield data[i:min(i + batch_size, data_len)]

    for samples in genNextSamples():
        batch = createBatch(samples)
        batches.append(batch)
    return batches


def sentence2enco(sentence, word2id):
    '''
    测试的时候将用户输入的句子转化为可以直接feed进模型的数据,现将句子转化成id,然后调用createBatch处理
    :param sentence: 用户输入的句子
    :param word2id: 单词与id之间的对应关系字典
    :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度
    :return: 处理之后的数据,可直接feed进模型进行预测
    '''
    if sentence == '':
        return None
    # 分词
    tokens = [word for word in jieba.cut(sentence)]
    # tokens = sentence
    if len(tokens) > 20:
        return None
    # 将每个单词转化为id
    wordIds = []
    for token in tokens:
        wordIds.append(word2id.get(token, unknownToken))
    # 调用createBatch构造batch
    batch = createBatch([[wordIds, []]])
    return batch


train code

import torch
import random
from dataloader import loadDataset, getBatches, sentence2enco
from tqdm import tqdm
import math
import torch.optim as optim
import os
from config import config
from model import Encoder, AttnDecoder, MLP, Embedding
import time
import numpy as np

torch.backends.cudnn.enabled = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_forcing_ratio = 0.5
SOS_token = 1

config = config()


def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()

def createOptimizer(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, optimizer_type, learning_rate, weight_decay,iter):
    if optimizer_type == 'Adam':
        encoder_s_opt = optim.Adam(encoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_s_opt = optim.Adam(decoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        encoder_t_opt = optim.Adam(encoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_t_opt = optim.Adam(decoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        mlp_opt = optim.Adam(mlp.parameters(), lr=learning_rate, weight_decay=weight_decay)
        emb_opt = optim.Adam(embedding.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        encoder_s_opt = optim.SGD(encoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_s_opt = optim.SGD(decoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        encoder_t_opt = optim.SGD(encoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_t_opt = optim.SGD(decoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        mlp_opt = optim.SGD(mlp.parameters(), lr=learning_rate, weight_decay=weight_decay)
        emb_opt = optim.SGD(embedding.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if iter >= config._after - 1:
        for param_group in encoder_s_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in decoder_s_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in encoder_t_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in decoder_t_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in mlp_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in emb_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
    return encoder_s_opt, decoder_s_opt, encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt

def trainIters(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, optimizer_type, train_samples, val_samples,
               learning_rate, weight_decay):
    current_step = 0
    best_loss = 100.
    val_batches = getBatches(val_samples, config.batch_size)
    # train_batches = getBatches(train_samples, config.batch_size)
    for iter in range(config.numEpochs):
        # if (iter + 1 ) % 10 == 0:
        train_batches = getBatches(train_samples, config.batch_size)
        print("----- Epoch {}/{} -----".format(iter + 1, config.numEpochs))
        loss_fn = torch.nn.MSELoss(reduction='none')
        encoder_s_opt, decoder_s_opt, encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt = createOptimizer(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, optimizer_type, learning_rate, weight_decay,iter)
        for next_batch in train_batches:
            current_step += 1
            loss_train = train(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, encoder_s_opt, decoder_s_opt,
                                   encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt, next_batch, config.clip_max_norm, loss_fn)
            if current_step % config.steps_per_checkpoint == 0:
                total_loss = 0.
                total_per = 0.
                with torch.no_grad():
                    for nextBatch in val_batches:
                        _loss = val(encoder_s, decoder_t, mlp, embedding, nextBatch)
                        perplexity = math.exp(float(_loss)) if _loss < 50 else float('inf')
                        total_loss += _loss * len(nextBatch.encoder_inputs_length)
                        total_per += perplexity * len(nextBatch.encoder_inputs_length)
                val_loss = total_loss / len(val_samples)
                val_per = total_per / len(val_samples)
                tqdm.write("----- Step %d -- Loss_train %.4f -- Loss_test %.4f -- Time %s" % (current_step, loss_train, val_loss, time.strftime('%Y.%m.%d %H:%M:%S', time.localtime(time.time()))))
                if val_loss < best_loss:
                    best_loss = val_loss
                    torch.save(encoder_s, config.model_dir + '/encoder_s_val')
                    torch.save(decoder_s, config.model_dir + '/decoder_s_val')
                    torch.save(encoder_t, config.model_dir + '/encoder_t_val')
                    torch.save(decoder_t, config.model_dir + '/decoder_t_val')
                    torch.save(mlp, config.model_dir + '/mlp_val')
                    torch.save(embedding, config.model_dir + '/embedding_val')
                else:
                    torch.save(encoder_s, config.model_dir + '/encoder_s')
                    torch.save(decoder_s, config.model_dir + '/decoder_s')
                    torch.save(encoder_t, config.model_dir + '/encoder_t')
                    torch.save(decoder_t, config.model_dir + '/decoder_t')
                    torch.save(mlp, config.model_dir + '/mlp')
                    torch.save(embedding, config.model_dir + '/embedding')



def train(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, encoder_s_opt, decoder_s_opt,
                                   encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt, batch,
          clip_max_norm, loss_fn):
    encoder_s.train()
    decoder_s.train()
    encoder_t.train()
    decoder_t.train()
    mlp.train()
    embedding.train()
    encoder_s_opt.zero_grad()
    decoder_s_opt.zero_grad()
    encoder_t_opt.zero_grad()
    decoder_t_opt.zero_grad()
    mlp_opt.zero_grad()
    emb_opt.zero_grad()
    encoder_inputs = batch.encoder_inputs.to(device)
    encoder_inputs_length = batch.encoder_inputs_length.to(device)
    mask_s = batch.mask_s.to(device)
    decoder_targets = batch.decoder_targets.to(device)
    decoder_targets_length = batch.decoder_targets_length.to(device)
    mask_t = batch.mask_t.to(device)
    batch_size = len(batch.encoder_inputs)

    #计算J1(θ)=−logP(x̃ |x;θ)
    loss_1 = 0
    print_losses_1 = 0
    n_totals_1 = 0
    output_length = encoder_inputs.size()[1]
    encoder_output, encoder_hidden = encoder_s(embedding(encoder_inputs), encoder_inputs_length)
    decoder_hidden = encoder_hidden
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)
    use_teacher_forcing = True if random.random() < config.SCHEDULED_SAMPLING_RATIO else False
    if use_teacher_forcing:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_s(embedding(decoder_input), decoder_hidden, encoder_output)
            decoder_input = encoder_inputs[:, i].view(batch_size, 1)
            mask_loss_1, nTotal_1 = maskNLLLoss(logits, encoder_inputs[:, i], mask_s[:, i])
            loss_1 += mask_loss_1
            print_losses_1 += mask_loss_1.item() * nTotal_1
            n_totals_1 += nTotal_1
    else:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_s(embedding(decoder_input), decoder_hidden, encoder_output)
            topv, topi = logits.data.topk(1)
            ni = topi
            decoder_input = torch.LongTensor(ni).view(batch_size, 1).to(device)
            mask_loss_1, nTotal_1 = maskNLLLoss(logits, encoder_inputs[:, i], mask_s[:, i])
            loss_1 += mask_loss_1
            print_losses_1 += mask_loss_1.item() * nTotal_1
            n_totals_1 += nTotal_1

    #计算J2(φ)=−logP(ỹ|y;φ)
    loss_2 = 0
    print_losses_2 = 0
    n_totals_2 = 0
    output_length = decoder_targets.size()[1]
    encoder_output, encoder_hidden = encoder_t(embedding(decoder_targets), decoder_targets_length, use_pack=False)
    decoder_hidden = encoder_hidden
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)
    use_teacher_forcing = True if random.random() < config.SCHEDULED_SAMPLING_RATIO else False
    if use_teacher_forcing:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden, encoder_output)
            decoder_input = decoder_targets[:, i].view(batch_size, 1)
            mask_loss_2, nTotal_2 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_2 += mask_loss_2
            print_losses_2 += mask_loss_2.item() * nTotal_2
            n_totals_2 += nTotal_2
    else:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden, encoder_output)
            topv, topi = logits.data.topk(1)
            ni = topi
            decoder_input = torch.LongTensor(ni).view(batch_size, 1).to(device)
            mask_loss_2, nTotal_2 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_2 += mask_loss_2
            print_losses_2 += mask_loss_2.item() * nTotal_2
            n_totals_2 += nTotal_2

    #计算J3(γ)= 1∥t−s∥2
    _, s = encoder_s(embedding(encoder_inputs), encoder_inputs_length)
    _, t = encoder_t(embedding(decoder_targets), decoder_targets_length, use_pack=False)

    loss_3 = torch.sum(loss_fn(mlp(s), t))

    #计算J4(θ,φ,γ) = −logP(y|x;θ,φ,γ)
    loss_4 = 0
    print_losses_4 = 0
    n_totals_4 = 0
    output_length = decoder_targets.size()[1]
    encoder_output, encoder_hidden = encoder_s(embedding(encoder_inputs),encoder_inputs_length)
    decoder_hidden = mlp(encoder_hidden)
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)
    use_teacher_forcing = True if random.random() < config.SCHEDULED_SAMPLING_RATIO else False
    if use_teacher_forcing:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden, encoder_output)
            decoder_input = decoder_targets[:, i].view(batch_size, 1)
            mask_loss_4, nTotal_4 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_4 += mask_loss_4
            print_losses_4 += mask_loss_4.item() * nTotal_4
            n_totals_4 += nTotal_4
    else:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden, encoder_output)
            topv, topi = logits.data.topk(1)
            ni = topi
            decoder_input = torch.LongTensor(ni).view(batch_size, 1).to(device)
            mask_loss_4, nTotal_4 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_4 += mask_loss_4
            print_losses_4 += mask_loss_4.item() * nTotal_4
            n_totals_4 += nTotal_4

    loss = loss_1 + loss_2 + 0.01 * loss_3 + loss_4
    loss.backward()
    torch.nn.utils.clip_grad_norm_(encoder_s.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(decoder_s.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(encoder_t.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(decoder_t.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(mlp.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(embedding.parameters(), max_norm=clip_max_norm)
    decoder_s_opt.step()
    encoder_s_opt.step()
    decoder_t_opt.step()
    encoder_t_opt.step()
    mlp_opt.step()
    emb_opt.step()


    return print_losses_4 / n_totals_4

def val(encoder, decoder, mlp, embedding, batch):
    encoder.eval()
    decoder.eval()
    mlp.eval()
    embedding.eval()
    encoder_inputs = batch.encoder_inputs.to(device)
    encoder_inputs_length = batch.encoder_inputs_length.to(device)
    decoder_targets = batch.decoder_targets.to(device)
    mask_t = batch.mask_t.to(device)
    batch_size = len(batch.encoder_inputs)
    output_length = decoder_targets.size()[1]
    encoder_output, encoder_hidden = encoder(embedding(encoder_inputs), encoder_inputs_length)
    decoder_hidden = mlp(encoder_hidden)
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)

    loss = 0
    print_losses = 0
    n_totals = 0

    for i in range(0, output_length):
        decoder_output, logits, _, decoder_hidden = decoder(embedding(decoder_input), decoder_hidden, encoder_output)
        decoder_input = decoder_targets[:, i].view(batch_size, 1)
        mask_loss, nTotal = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
        loss += mask_loss
        print_losses += mask_loss.item() * nTotal
        n_totals += nTotal
    return print_losses / n_totals

def build_model():
    data_path = config.data_path
    word2id, id2word, pretrain_embedding, train_samples, val_samples, test_samples = loadDataset(
        data_path)
    if os.path.exists(config.model_dir):
        print('Reloading model from ' + config.model_dir)
        encoder_s = torch.load(config.model_dir + '/encoder_s')
        decoder_s = torch.load(config.model_dir + '/decoder_s')
        encoder_t = torch.load(config.model_dir + '/encoder_t')
        decoder_t = torch.load(config.model_dir + '/decoder_t')
        mlp = torch.load(config.model_dir + '/mlp')
        embedding = torch.load(config.model_dir + '/embedding')
    else:
        print('Building model to ' + config.model_dir)
        os.mkdir(config.model_dir)
        encoder_s = Encoder(input_size=config.embedding_size,
                            hidden_size=config.cell_size,
                            drop_prob=config.keep_prob,
                            cell_type=config.cell_name,
                            nonlinearity=config.nonlinearity,
                            num_layers=config.num_layers,
                            bidirectional=config.bidirectional
                            )

        decoder_s = AttnDecoder(input_size=config.embedding_size,
                                output_size=len(word2id),
                                hidden_size=config.cell_size,
                                drop_prob=config.keep_prob,
                                cell_type=config.cell_name,
                                nonlinearity=config.nonlinearity,
                                num_layers=config.num_layers,
                                bidirectional=config.bidirectional,
                                attn=config.attn
                                )
        encoder_t = Encoder(input_size=config.embedding_size,
                            hidden_size=config.cell_size,
                            drop_prob=config.keep_prob,
                            cell_type=config.cell_name,
                            nonlinearity=config.nonlinearity,
                            num_layers=config.num_layers,
                            bidirectional=config.bidirectional
                            )

        decoder_t = AttnDecoder(input_size=config.embedding_size,
                                output_size=len(word2id),
                                hidden_size=config.cell_size,
                                drop_prob=config.keep_prob,
                                cell_type=config.cell_name,
                                nonlinearity=config.nonlinearity,
                                num_layers=config.num_layers,
                                bidirectional=config.bidirectional,
                                attn=config.attn
                                )
        mlp = MLP(config.cell_size, config.cell_size)
        embedding = Embedding(pretrain_embedding)

        torch.save(encoder_s.to(device), config.model_dir + '/encoder_s')
        torch.save(decoder_s.to(device), config.model_dir + '/decoder_s')
        torch.save(encoder_t.to(device), config.model_dir + '/encoder_t')
        torch.save(decoder_t.to(device), config.model_dir + '/decoder_t')
        torch.save(mlp.to(device), config.model_dir + '/mlp')
        torch.save(embedding.to(device), config.model_dir + '/embedding')
    return encoder_s.to(device), decoder_s.to(device), encoder_t.to(device), decoder_t.to(device), mlp.to(
        device), embedding.to(device), train_samples, val_samples, test_samples


encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, train_samples, val_samples, test_samples = build_model()
trainIters(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, config.optimizer_type, train_samples, val_samples,
           config.learning_rate, config.weight_decay)

未能找到问题,代码有点多,建议再进一步定位一下吃内存的具体代码片段。(suppose you understand chinese)