OOM error on using multiple pytorch datasets on a model

I’m training an LSTM-RNN for machine translation on the Multi30k dataset using torchtext. The dataset seems to contain multiple files for training and validation, but the default parameters load only one training file and one validation file per language.
When I iterate over the dataset with the default parameters, I can train my model for 30 epochs with no problems.
However, when I create multiple datasets, each iterating over a separate file, PyTorch throws an OOM error after a few iterations of the second dataset.

Code for the multi-dataset epoch, where one pass of training over every file constitutes a full epoch:

def full_data_epoch(encoder, decoder, device, enc_optim, dec_optim, sos_tok, src_vocab, trg_vocab, data_dir=".data"):
    src_train_files, trg_train_files, src_val_files, trg_val_files = get_file_names(data_dir)

    for src_train_file, trg_train_file, src_val_file, trg_val_file in zip(src_train_files, trg_train_files, src_val_files, trg_val_files):
        train_dataset, valid_dataset, test_dataset = Multi30k(
                                                train_filenames = (str(src_train_file), str(trg_train_file)),
                                                valid_filenames = (str(src_val_file), str(trg_val_file)),
                                                data_select = ("train", "val"),
                                                vocab= (src_vocab, trg_vocab))
        batch_size = 8
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 
                                                   collate_fn=col_func(), drop_last=False)
        val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, 
                                                collate_fn=col_func(), drop_last=False)

        train_loss, train_acc = loader_epoch(encoder, decoder, train_loader, device, 
                                            loss_func, optims=[enc_optim, dec_optim], sos_tok=sos_tok)
        val_loss, val_acc = loader_epoch(encoder, decoder, val_loader, device, 
                                        loss_func, optims=None, sos_tok=sos_tok)

        print(f"train loss:{train_loss}\tval loss:{val_loss}")
        print(f"train accuracy:{train_acc}\tval accuracy:{val_acc}")
        with torch.no_grad():

main training:

device = "cuda"
# Scan all files in the data directory to create the vocab
src_vocab, trg_vocab = get_vocab(src_tok=de_tok, trg_tok=en_tok, data_dir=".data")
sos_tok = torch.tensor([src_vocab.stoi["<sos>"]])
sos_tok = sos_tok[None, :]

encoder = RNNModel(len(src_vocab.stoi), 300, 512, 0.5)
decoder = RNNModel(len(trg_vocab.stoi), 300, 512, 0.5, 0.5, "batch")
encoder = encoder.to(device)
decoder = decoder.to(device)

loss_func = torch.nn.CrossEntropyLoss()
enc_optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)
dec_optim = torch.optim.Adam(decoder.parameters(), lr=1e-4)

num_epochs = 30
for epoch in tqdm(range(num_epochs)):
    full_data_epoch(encoder, decoder, device, enc_optim, dec_optim, sos_tok, src_vocab, trg_vocab)


CUDA out of memory. Tried to allocate 534.00 MiB (GPU 0; 3.00 GiB total capacity; 1.60 GiB already allocated; 408.52 MiB free; 1.63 GiB reserved in total by PyTorch)
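For a sense of scale, here is a rough sketch of how large a single dense float32 tensor of shape (seq_len, batch_size, vocab_size) gets, which is the shape of the decoder output before the reshape in model_iter. The vocabulary size of 10,000 is a hypothetical placeholder (the real value is len(trg_vocab.stoi)), and nothing in the error message confirms that the failed 534 MiB allocation was this particular tensor; it is only one plausible candidate.

```python
def logits_mib(seq_len: int, batch_size: int, vocab_size: int,
               bytes_per_elem: int = 4) -> float:
    """Size in MiB of a dense (seq_len, batch_size, vocab_size) tensor.

    bytes_per_elem defaults to 4, i.e. float32.
    """
    return seq_len * batch_size * vocab_size * bytes_per_elem / 2**20


# Hypothetical vocabulary size; substitute len(trg_vocab.stoi).
VOCAB_SIZE = 10_000
# Batch size used in full_data_epoch.
BATCH_SIZE = 8

for seq_len in (30, 300, 1750):
    size = logits_mib(seq_len, BATCH_SIZE, VOCAB_SIZE)
    print(f"seq_len={seq_len:5d} -> {size:8.1f} MiB")
```

Under these assumed numbers, a batch of 8 only reaches the 534 MiB range when the padded sequence length is well over a thousand tokens, so it may be worth checking whether any batch contains an unusually long sequence.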

EDIT: loader_epoch trains the model on the given loader for one epoch:

def model_iter(encoder, decoder, inp_batch, out_batch, loss_func, sos_tok):
    enc_out, enc_hidden, enc_cell = encoder(inp_batch)

    # dec_hidden = enc_hidden
    # dec_cell = enc_cell
    # print(out_batch.shape)
    sos_tok = sos_tok.expand(1, inp_batch.shape[1]).to(device)
    out_batch_in = torch.cat([sos_tok, out_batch], dim=0)
    dec_out, _, _ = decoder(out_batch_in, enc_hidden, enc_cell)
    dec_out = dec_out[:-1].reshape(-1, dec_out.shape[-1])
    out_batch = out_batch.reshape(-1)
    with torch.no_grad():
        dec_acc = dec_out.detach()
        out_acc = out_batch.detach()
        top_v, top_i = torch.topk(dec_acc, 1, dim=-1)
        top_i = top_i.squeeze()
        total = top_i.shape[0]
        # print(f"top:{top_i.shape[0]}\tout:{out_batch.shape}")
        diff = (top_i-out_acc)
        diff[diff != 0] = 1
        # print(f"diff:{diff.nonzero()}")
        wrong = diff[diff.nonzero(as_tuple=True)].sum()
        # print(f"wrong:{wrong}")
        correct = total - wrong
        # print(f"correct:{correct}\ttotal:{total}")
        acc = float(correct)/float(total)

    return loss_func(dec_out, out_batch), acc

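def loader_epoch(encoder, decoder, loader, device, loss_func, optims=None, sos_tok=None):
    # Presumably: no optimizers means evaluation mode (dropout off), otherwise training mode.
    if optims is None:
        encoder.eval()
        decoder.eval()
    else:
        encoder.train()
        decoder.train()

    total_loss = 0
    accuracy = 0
    for idx, (inp_batch, out_batch) in tqdm(enumerate(loader), total=(len(loader))):

        inp_batch = inp_batch.to(device)
        out_batch = out_batch.to(device)

        # if idx == len(loader)-1:
        #     set_trace()

        if optims is None:
            # Evaluation: skip building the autograd graph.
            with torch.no_grad():
                loss, acc = model_iter(encoder, decoder, inp_batch, out_batch, loss_func, sos_tok)
        else:
            loss, acc = model_iter(encoder, decoder, inp_batch, out_batch, loss_func, sos_tok)

        # .item() stores a Python float, so the graph is not kept alive across batches.
        total_loss += loss.item()
        accuracy += acc

        if optims is not None:
            # Standard update: clear old gradients, backpropagate, step both optimizers.
            for optim in optims:
                optim.zero_grad()
            loss.backward()
            for optim in optims:
                optim.step()

    # idx is zero-based, so the number of batches seen is idx + 1.
    return total_loss / (idx + 1), accuracy / (idx + 1)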

I’m training in a Jupyter notebook; could that be a potential issue?

A couple of things:

  1. What exactly is happening inside loader_epoch?
  2. I was wondering whether these lines could be moved inside the full_data_epoch function.

The things that could have gone wrong:

  1. The vocab size for one of the files is too large to fit in memory.
  2. The model is dynamically modified during training and becomes too large to fit in memory.


  1. It takes the encoder, the decoder, and the requisite dataloader, and trains for one epoch.
  2. Sure, they can be, though I don’t know what use that would be.


  1. I’m creating one vocab for all the files in one go; that’s what get_vocab() does. I think we need a single vocab, otherwise part of the model architecture would have to change from file to file. So if we can get through one epoch, the vocab should not be the issue.
  2. Following from point 1, the model is not changing dynamically.

On further testing, it seems that the files used to create the dataset may be the issue. I created a dataset by manually entering the file names to be used for the validation and test splits. Any file that is not one of the default files used to create the dataset tends to throw an OOM error. In other words, only
train.de, val.de, train.en, val.en
are the files that can be loaded and trained on.
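One way to check whether the non-default files differ from the default ones is to compare their line counts and sentence lengths. The following is a minimal sketch, not part of the original code: line_length_stats and report are hypothetical helper names, whitespace splitting stands in for the real tokenizers (de_tok, en_tok), and the glob patterns assume the files end in .de and .en.

```python
from pathlib import Path


def line_length_stats(path, tokenize=str.split):
    """Return (num_lines, max_tokens, mean_tokens) for a text file.

    tokenize defaults to whitespace splitting, which is only an
    approximation of the tokenizers used for training.
    """
    lengths = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            lengths.append(len(tokenize(line)))
    if not lengths:
        return 0, 0, 0.0
    return len(lengths), max(lengths), sum(lengths) / len(lengths)


def report(data_dir=".data", patterns=("*.de", "*.en")):
    """Print per-file statistics for every matching file in data_dir."""
    root = Path(data_dir)
    if not root.is_dir():
        print(f"{data_dir} is not a directory")
        return
    for pattern in patterns:
        for path in sorted(root.glob(pattern)):
            n, longest, mean = line_length_stats(path)
            print(f"{path.name:30s} lines={n:6d} "
                  f"max_tokens={longest:5d} mean_tokens={mean:6.1f}")
```

Calling report(".data") prints one row per file. If a file that triggers the OOM shows a much larger max_tokens than train.de or train.en, a single padded batch from that file would be correspondingly larger on the GPU.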

Lastly, it doesn’t matter whether we create the vocab ourselves or it is created automatically along with the dataset.