Don't know why the memory exploded

Hi, I am new to pytorch and I am practicing the sequence to sequence model according to the official tutorial provided by pytorch. Everything worked fine when I run the original tutorial. But when I used my own way to preprocess data but remained the encoder and decoder and training procedure the same, my colab always crashes due to the RAM overflow no matter how I cut the size of hidden units or reducing training examples (and I believe the training examples cannot be the issue since they are small enough compared to a usual one). I just have no idea about how to debug this, every suggestion or help will be appreciated!

The dataset I used is exactly the same as that in the tutorial, here is a link for the convenience:
https://drive.google.com/file/d/1yE8Wm1drqYm4Ogplsrym499PcTn3ji4j/view?usp=sharing

import torch
import torch.nn as nn
import re;
from nltk import word_tokenize
import nltk
import numpy as np;
from torch import optim
import random
import json
import os
# from classObj import Language, EncoderRNN, DecoderRNN;
import pickle
nltk.download('punkt');
from pylab import *;

device = 'cuda' if torch.cuda.is_available() else 'cpu';

class Language:
    def __init__(self, name):
        self.name = name;
        self.word2idx = {};
        self.idx2word = {0:'<s>', 1:'</s>'};
        self.wordcount = {};
        self.n_words = 2;

    def addSentence(self, sentence):
        for word in sentence:
            if word in self.word2idx:
                self.wordcount[word] += 1;
            else:
                self.word2idx[word] = len(self.idx2word);
                self.idx2word[len(self.idx2word)] = word;
                self.wordcount[word] = 1;
                self.n_words += 1;

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

#For this cell reference: Sean Robertson <https://github.com/spro/practical-pytorch>
import time;
import math

def asMinutes(s):
    m = math.floor(s/60);
    s -= m*60;
    return '%dm %ds' % (m, s);

def timeSince(since, percent):
    now = time.time();
    s = now - since;
    es = s/percent;
    rs = es - s;
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs));

import matplotlib.pyplot as plt;
import numpy as np;
import matplotlib.ticker as ticker;

def showPlot(points):
    plt.figure();
    fig, ax = plt.subplots();
    loc = ticker.MultipleLocator(base=0.2);
    ax.yaxis.set_major_locator(loc);
    plt.plot(points);
    plt.show()

device = 'cuda' if torch.cuda.is_available() else 'cpu';
print(device)
SOS = 0;
EOS = 1;

def readLangs(lang1, lang2, reverse=False):
    pairs = None;
    if os.path.isfile('training_pairs.txt'):
        with open('training_pairs.txt', 'r', encoding='utf-8') as f:
            pairs = json.load(f);
        print('Training data loaded!');
    else:
        print('Reading %s and %s corpus' % (lang1, lang2));
        lines = open('%s-%s.txt'%(lang1, lang2), 'r', encoding='utf-8').read().strip().lower().split('\n');
        pairs = [[word_tokenize(sentence) for sentence in l.split('\t')] for l in lines];
        with open('training_pairs.txt', 'w', encoding='utf-8') as f:
            json.dump(pairs, f);
        print('Training data saved!');
        # exit(0);
    if reverse:
        pairs = [reversed(p) for p in pairs];
        input_lang = Language(lang2);
        output_lang = Language(lang1);
    else:
        input_lang = Language(lang1);
        output_lang = Language(lang2);
    return input_lang, output_lang, pairs;

def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse);
    print('Loading corpus for both languages...');
    if os.path.isfile('Language1.class'):
        try:
            with open('Language1.class', 'rb') as f:
                input_lang = pickle.load(f);
            print('%s class loaded successfully!' % lang1);
            with open('Language2.class', 'rb') as f:
                output_lang = pickle.load(f);
            print('%s class loaded successfully!' % lang2);
        except:
            print('Loading language classes failed :(');
            exit(1);
    else:
        i = 0
        for p in pairs:
            i += 1;
            if len(p) != 2:
              print(p, i);
            input_lang.addSentence(p[0]);
            output_lang.addSentence(p[1]);
        with open('Language1.class', 'wb') as f:
            pickle.dump(input_lang, f);
        with open('Language2.class', 'wb') as f:
            pickle.dump(output_lang, f);
        print('Languages loaded!');
    print('Language1: %s with %d words' % (input_lang.name, input_lang.n_words));
    print('Language2: %s with %d words' % (output_lang.name, output_lang.n_words));
    return input_lang, output_lang, pairs;

def pair2tensor(pair, lang1, lang2):
    input = pair[0];
    output = pair[1];
    input = [lang1.word2idx[word] for word in input];
    input.append(EOS);
    output = [lang2.word2idx[word] for word in output];
    output.append(EOS);
    input = torch.tensor(input, dtype=torch.long, device=device).view(-1, 1);
    output = torch.tensor(output, dtype=torch.long, device=device).view(-1, 1);
    return input, output;

input_lang, output_lang, pairs = prepareData('eng', 'fra');
pairs = [list(pair2tensor(pair, input_lang, output_lang)) for pair in pairs];
pairs = np.array(pairs);
train_set = pairs[:50000, 0];
train_label = pairs[:50000, 1];
print(pairs.nbytes);

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=10):
    encoder_hidden = encoder.hiddenInit();
    encoder_optimizer.zero_grad();
    decoder_optimizer.zero_grad();
    input_length = input_tensor.size(0);
    target_length = target_tensor.size(0);
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device);
    loss = 0;
    for ei in range(input_length):
        if ei >= max_length:
            break;
        encoder_output, hidden_cell = encoder(input_tensor[ei], encoder_hidden);
        encoder_outputs[ei] = encoder_output[0, 0];
    decoder_input = torch.tensor([[SOS]], device=device);
    decoder_hidden = encoder_hidden;
    use_teacher_forcing = True if random.random() < 0.5 else False;
    if use_teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden);
            # print(decoder_output, target_tensor[di]);
            loss += criterion(decoder_output[0], target_tensor[di]);
            decoder_input = target_tensor[di];
        # exit(0)
    else:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden);
            topv, topi = decoder_output.topk(1);
            decoder_input = topi.squeeze().detach();
            loss += criterion(decoder_output[0], target_tensor[di]);
            if decoder_input.item() == EOS:
                break;
    loss.backward();
    encoder_optimizer.step();
    decoder_optimizer.step();
    return loss.item()/target_length;

def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    start = time.time();
    plot_losses = [];
    print_loss_total = 0;
    plot_loss_total = 0;
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate);
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate);
    criterion = nn.NLLLoss();
    for iter in range(1, n_iters+1):
        input_tensor = train_set[iter-1];
        target_tensor = train_label[iter-1];
        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion);
        print_loss_total += loss;
        plot_loss_total += loss;
        if iter%print_every == 0:
            print_loss_avg = print_loss_total/print_every;
            print_loss_total = 0;
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter/n_iters), iter, iter/n_iters*100, print_loss_avg));
        if iter%plot_every == 0:
            plot_loss_avg = plot_loss_total/plot_every;
            plot_losses.append(plot_loss_avg);
            plot_loss_total = 0;
    # torch.save(encoder, '5000_encoder.model');
    # torch.save(decoder, '5000_decoder.model');
    showPlot(plot_losses);

def evaluate(encoder, decoder, pair, max_length=10):
    with torch.no_grad():
        input_tensor = pair[0];
        input_length = input_tensor.size(0);
        encoder_hidden = encoder.hiddenInit();
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device);
        for ei in range(input_length):
            if ei >= max_length:
                break;
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden);
            encoder_outputs[ei] += encoder_output[0, 0];
        decoder_input = torch.tensor([[SOS]], device=device);
        decoder_hidden = encoder_hidden;
        decoded_words = [];
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden);
            topv, topi = decoder_output.data.topk(1);
            if topi.item() == EOS:
                decoded_words.append('</s>');
                break;
            else:
                decoded_words.append(output_lang.idx2word[topi.item()]);
            decoder_output = topi.squeeze().detach();
        return decoded_words

def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs);
        print('>', ' '.join(input_lang.idx2word[idx.item()] for idx in pair[0]));
        print('=', ' '.join(output_lang.idx2word[idx.item()] for idx in pair[1]));
        output_words = evaluate(encoder, decoder, pair);
        output_sentence = ' '.join(output_words);
        print('<', output_sentence);
        print('');

hidden_size = 64;
print('\nUsing device');
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device);
decoder1 = DecoderRNN( output_lang.n_words, hidden_size).to(device);
trainIters(encoder1, decoder1, 50000, print_every=1000);
torch.save(encoder1, '50000_encoder.model');
torch.save(decoder1, '50000_decoder.model');
evaluateRandomly(encoder1, decoder1);