I’m training an LSTM-RNN to do machine translation on the Multi30k dataset using torchtext. The dataset seems to contain multiple files for training and validation, but the default parameters load only one training file and one validation file per language.
When I iterate over the dataset with default params I can train my model with no problems for 30 epochs.
However, when I create multiple datasets, each iterating over a separate file, PyTorch throws an OOM error after a few iterations of the second dataset.
code for multi dataset epoch, where an epoch of training on every file constitutes a full epoch:
def full_data_epoch(encoder, decoder, device, enc_optim, dec_optim, sos_tok, src_vocab, trg_vocab, data_dir=".data"):
    """Train and validate once on every file group found under ``data_dir``.

    ``get_file_names(data_dir)`` supplies four parallel lists (source/target
    training files, source/target validation files).  For each aligned group a
    fresh Multi30k dataset pair is built with the shared vocabularies, one
    training pass and one validation pass are run through ``loader_epoch``,
    and the resulting losses and accuracies are printed.

    NOTE: ``loss_func`` is not a parameter; it is read from the enclosing
    module scope and must be defined before this function is called.
    """
    batch_size = 8

    def make_loader(dataset):
        # Identical loader settings for the training and validation splits.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=col_func(),
            drop_last=False,
        )

    file_groups = zip(*get_file_names(data_dir))
    for src_train_file, trg_train_file, src_val_file, trg_val_file in file_groups:
        train_files = (str(src_train_file), str(trg_train_file))
        valid_files = (str(src_val_file), str(trg_val_file))
        train_dataset, valid_dataset, test_dataset = Multi30k(
            train_filenames=train_files,
            valid_filenames=valid_files,
            data_select=("train", "val"),
            vocab=(src_vocab, trg_vocab),
        )

        train_loader = make_loader(train_dataset)
        val_loader = make_loader(valid_dataset)

        # Passing the optimizers makes loader_epoch update the models.
        train_loss, train_acc = loader_epoch(
            encoder, decoder, train_loader, device, loss_func,
            optims=[enc_optim, dec_optim], sos_tok=sos_tok,
        )
        # optims=None makes loader_epoch run a gradient-free evaluation pass.
        val_loss, val_acc = loader_epoch(
            encoder, decoder, val_loader, device, loss_func,
            optims=None, sos_tok=sos_tok,
        )

        print(f"train loss:{train_loss}\tval loss:{val_loss}")
        print(f"train accuracy:{train_acc}\tval accuracy:{val_acc}")

        # Hand cached, currently unused CUDA blocks back before the next file.
        with torch.no_grad():
            torch.cuda.empty_cache()
main training:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Scan all files in the data directory to create the vocab
src_vocab, trg_vocab = get_vocab(src_tok=de_tok, trg_tok=en_tok, data_dir=".data")

# The start-of-sequence token is prepended to the *target* sequence that is
# fed to the decoder (see model_iter), so its index must be looked up in the
# target vocabulary, not the source one.  Shape (1, 1) so it can be expanded
# across the batch dimension.
sos_tok = torch.tensor([[trg_vocab.stoi["<sos>"]]])

encoder = RNNModel(len(src_vocab.stoi), 300, 512, 0.5)
decoder = RNNModel(len(trg_vocab.stoi), 300, 512, 0.5, 0.5, "batch")
encoder = encoder.to(device)
decoder = decoder.to(device)

loss_func = torch.nn.CrossEntropyLoss()
enc_optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)
dec_optim = torch.optim.Adam(decoder.parameters(), lr=1e-4)
num_epochs = 30

print(len(src_vocab))
print(len(trg_vocab))
print("vocab_collected")

# One "epoch" here is a full sweep over every training/validation file pair.
for epoch in tqdm(range(num_epochs)):
    full_data_epoch(encoder, decoder, device, enc_optim, dec_optim, sos_tok, src_vocab, trg_vocab)
Error:
CUDA out of memory. Tried to allocate 534.00 MiB (GPU 0; 3.00 GiB total capacity; 1.60 GiB already allocated; 408.52 MiB free; 1.63 GiB reserved in total by PyTorch)
EDIT: `loader_epoch` trains the model on the given loader for one epoch; it and its helper `model_iter` are shown below:
def model_iter(encoder, decoder, inp_batch, out_batch, loss_func, sos_tok):
    """Run one encoder/decoder forward pass and score it.

    The encoder's hidden and cell states initialise the decoder.  The decoder
    is fed ``sos_tok`` followed by ``out_batch``, and its outputs (minus the
    final time step) are compared against ``out_batch``.

    Args:
        encoder: callable returning ``(output, hidden, cell)`` for ``inp_batch``.
        decoder: callable taking ``(tokens, hidden, cell)`` and returning
            ``(output, hidden, cell)``.
        inp_batch: source token tensor; dimension 1 is used as the batch size.
        out_batch: target token tensor, concatenated with ``sos_tok`` along
            dimension 0 (so presumably ``(seq_len, batch)`` - confirm against
            the collate function).
        loss_func: callable applied to ``(flattened logits, flattened targets)``.
        sos_tok: tensor of shape ``(1, 1)`` holding the start-of-sequence index.

    Returns:
        ``(loss, acc)`` where ``loss`` is whatever ``loss_func`` returns and
        ``acc`` is the fraction of target positions whose arg-max prediction
        matches, as a Python float.
    """
    _, enc_hidden, enc_cell = encoder(inp_batch)

    # Place the start token on the same device as the targets instead of
    # depending on a module-level ``device`` variable.
    sos_row = sos_tok.expand(1, inp_batch.shape[1]).to(out_batch.device)
    dec_input = torch.cat([sos_row, out_batch], dim=0)

    dec_out, _, _ = decoder(dec_input, enc_hidden, enc_cell)
    # Drop the last step: it is the prediction made after the final target
    # token and has nothing to be compared with.
    dec_out = dec_out[:-1].reshape(-1, dec_out.shape[-1])
    out_batch = out_batch.reshape(-1)

    with torch.no_grad():
        # argmax over the vocabulary axis keeps a 1-D result even when there
        # is only a single target position (a bare squeeze() would not).
        predictions = dec_out.argmax(dim=-1)
        total = out_batch.shape[0]
        correct = int((predictions == out_batch).sum().item())
        acc = correct / total

    return loss_func(dec_out, out_batch), acc
def loader_epoch(encoder, decoder, loader, device, loss_func, optims=None, sos_tok=None):
    """Run one pass over ``loader`` and return the mean loss and accuracy.

    Args:
        encoder, decoder: the models; switched to train mode when ``optims``
            is given and to eval mode otherwise.
        loader: iterable of ``(inp_batch, out_batch)`` pairs that also
            supports ``len()``.
        device: device the batches are moved to before the forward pass.
        loss_func: loss callable forwarded to ``model_iter``.
        optims: optimizers to step after every batch, or ``None`` for a
            gradient-free evaluation pass.
        sos_tok: start-of-sequence tensor forwarded to ``model_iter``.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
        Both are ``nan`` when the loader yields no batches.
    """
    training = optims is not None
    if training:
        encoder.train()
        decoder.train()
    else:
        encoder.eval()
        decoder.eval()

    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    for inp_batch, out_batch in tqdm(loader, total=len(loader)):
        inp_batch = inp_batch.to(device)
        out_batch = out_batch.to(device)

        if training:
            for optim in optims:
                optim.zero_grad()
            loss, acc = model_iter(encoder, decoder, inp_batch, out_batch, loss_func, sos_tok)
            loss.backward()
            for optim in optims:
                optim.step()
        else:
            # No graph is built during evaluation, which keeps memory low.
            with torch.no_grad():
                loss, acc = model_iter(encoder, decoder, inp_batch, out_batch, loss_func, sos_tok)

        # .item() stores a plain float so the batch's graph can be freed.
        total_loss += loss.item()
        total_acc += acc
        num_batches += 1

    if num_batches == 0:
        # Nothing was processed; there is no meaningful average to report.
        return float("nan"), float("nan")

    # Divide by the batch count, not by the last zero-based index.
    return total_loss / num_batches, total_acc / num_batches
I’m training in a Jupyter notebook. Could that be a potential issue?