Any way to optimize collate_batch?

A lot of time is being spent in the collate_batch function of the RNN example below. Is there any way to optimize it? I am attaching a shortened log with timestamps at the end; I cannot paste the full log here because it is too long.

import torch
import torch.nn as nn
import code
from functools import wraps
from datetime import datetime

from torchtext.datasets import IMDB
train_dataset = IMDB(split='train')
test_dataset = IMDB(split='test')
import re

CONFIG_USE_ROCM=0

1. create dataset

from torch.utils.data.dataset import random_split

torch.manual_seed(1)
train_dataset, valid_dataset = random_split(list(train_dataset), [20000, 5000])
test_dataset=list(test_dataset)

2. find unique tokens

from collections import Counter, OrderedDict

def print_fcn_name(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        now = datetime.now()
        dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
        print(dt_string, ": ", func.__name__, " entered...")
        result = func(*args, **kwargs)
        return result

    return wrapper
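
To narrow down where the time actually goes, I am also considering a variant of this decorator that reports elapsed time per call instead of just the entry timestamp. This is only a rough sketch (print_fcn_time is my own name, not tested yet):

import time

def print_fcn_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()            # wall-clock start of this call
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start  # seconds spent inside func
        print(f"{func.__name__} took {elapsed:.3f} s")
        return result
    return wrapper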

#@print_fcn_name
def tokenizer(text):
    text = re.sub('<[^>]*>', '', text)
    emoticons = re.findall(
        '(?::|;|=)(?:-)?(?:\)|\(|D|P)', text.lower())
    text = re.sub('[\W]+', ' ', text.lower()) + \
        ' '.join(emoticons).replace('-', '')

    tokenized = text.split()
    return tokenized

token_counts = Counter()
for label, line in train_dataset:
    tokens = tokenizer(line)
    token_counts.update(tokens)
print('Vocab-size:', len(token_counts))

3. encoding each unique token into integers

from torchtext.vocab import vocab
sorted_by_freq_tuples = sorted(token_counts.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
vocab = vocab(ordered_dict)
vocab.insert_token("<pad>", 0)
vocab.insert_token("<unk>", 1)
vocab.set_default_index(1)

print([vocab[token] for token in ['this', 'is', 'an', 'example']])

3a. define the functions for transformation.

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: 1. if x == 'pos' else 0.

3b. wrap the encoding and transformation functions.

@print_fcn_name
def collate_batch(batch):
    label_list, text_list, lengths = [], [], []
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))

    if CONFIG_USE_ROCM:
        label_list = torch.tensor(label_list, device='cuda')
        lengths = torch.tensor(lengths, device='cuda')
    else:
        label_list = torch.tensor(label_list)
        lengths = torch.tensor(lengths)

    padded_text_list = nn.utils.rnn.pad_sequence(text_list, batch_first=True)
    if CONFIG_USE_ROCM:
        # .to() is not in-place; keep the result, and only move to the GPU when enabled
        padded_text_list = padded_text_list.to('cuda')
    #code.interact(local=locals())
    return padded_text_list, label_list, lengths
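
One idea I have been considering but have not tested: since collate_batch re-runs tokenizer() and the vocab lookup for every sample on every epoch, the dataset could be tokenized and numericalized once up front so that collate only has to pad. A rough sketch of what I mean (pre_encode, encoded_train and collate_encoded are my own names; it assumes the encoded dataset fits in memory):

def pre_encode(dataset):
    # One-time pass: pay the tokenizer + vocab-lookup cost once, not every epoch.
    encoded = []
    for _label, _text in dataset:
        ids = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        encoded.append((label_pipeline(_label), ids))
    return encoded

def collate_encoded(batch):
    # Only stacking and padding remain as per-batch work.
    label_list = torch.tensor([lbl for lbl, _ in batch])
    lengths = torch.tensor([ids.size(0) for _, ids in batch])
    padded_text_list = nn.utils.rnn.pad_sequence([ids for _, ids in batch], batch_first=True)
    return padded_text_list, label_list, lengths

# encoded_train = pre_encode(train_dataset)
# train_dl = DataLoader(encoded_train, batch_size=32, shuffle=True, collate_fn=collate_encoded)

valid_dataset and test_dataset would be pre-encoded the same way. Is that the recommended approach, or is there something more idiomatic?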

Take a small batch

from torch.utils.data import DataLoader
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=False, collate_fn=collate_batch)

text_batch, label_batch, length_batch = next(iter(dataloader))
print(text_batch)
print(label_batch)
print(length_batch)

batch_size=32
train_dl=DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
valid_dl=DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_dl=DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

#code.interact(local=locals())
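
I also wondered whether giving the DataLoader worker processes would at least overlap the collate cost with the training step; something like the sketch below (num_workers=4 is just a guess for this machine, and I have not checked how it interacts with the CONFIG_USE_ROCM branch that creates CUDA tensors inside collate_batch):

# Same loader, but with background workers so collate_batch runs in parallel
# with the training loop; valid_dl/test_dl would be changed the same way.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                      collate_fn=collate_batch, num_workers=4)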

embedding=nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)

a batch of 2 samples of 4 indices each

text_encoded_input = torch.LongTensor([[1,2,4,5],[4,3,2,0]])
print(embedding(text_encoded_input))

class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc1 = nn.Linear(rnn_hidden_size, fc_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fc_hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text, lengths):
        out = self.embedding(text)
        out = nn.utils.rnn.pack_padded_sequence(out, lengths.cpu().numpy(), enforce_sorted=False, batch_first=True)
        out, (hidden, cell) = self.rnn(out)
        out = hidden[-1, :, :]
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

vocab_size = len(vocab)
embed_dim = 20
rnn_hidden_size=64
fc_hidden_size=64
torch.manual_seed(1)

print("initializing model...")

model = RNN(vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size)
if CONFIG_USE_ROCM:
    model.to('cuda')

@print_fcn_name
def train(dataloader):
    model.train()
    total_acc, total_loss = 0, 0
    for text_batch, label_batch, lengths in dataloader:
        optimizer.zero_grad()
        pred = model(text_batch, lengths)[:, 0]
        loss = loss_fn(pred, label_batch)
        loss.backward()
        optimizer.step()
        total_acc += ((pred >= 0.5).float() == label_batch).float().sum().item()
        total_loss += loss.item() * label_batch.size(0)

    return total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset)

@print_fcn_name
def evaluate(dataloader):
    print("evaluate entered...")
    print(dataloader, len(dataloader), type(dataloader))
    model.eval()
    total_acc, total_loss = 0, 0
    with torch.no_grad():
        for text_batch, label_batch, lengths in dataloader:
            pred = model(text_batch, lengths)[:, 0]
            loss = loss_fn(pred, label_batch)
            total_acc += ((pred >= 0.5).float() == label_batch).float().sum().item()
            total_loss += loss.item() * label_batch.size(0)

    #code.interact(local=locals())
    print("total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset): ",
          total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset))
    return total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset)

print("setting loss function + optimizer...")
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5
torch.manual_seed(1)

print("start training...")
for epoch in range(num_epochs):
    print("EPOCH: ", epoch)
    acc_train, loss_train = train(train_dl)
    acc_valid, loss_valid = evaluate(valid_dl)
    print(f'Epoch {epoch} accuracy: {acc_train:.4f}'
          f' val_accuracy: {acc_valid:.4f}')

print("Evaluate...")
acc_test, _ = evaluate(test_dl)
print(f'test accuracy: {acc_test:.4f}')

shortened log:

[BEGIN] 9/27/2022 12:49:26 AM
[root@ixt-hq-180 ch15-rnn]# python3 p513-movie-review.py
Vocab-size: 69023
[11, 7, 35, 457]
27/09/2022 07:49:32 : collate_batch entered…
tensor([[ 35, 1739, 7, 449, 721, 6, 301, 4, 787, 9,
4, 18, 44, 2, 1705, 2460, 186, 25, 7, 24,
100, 1874, 1739, 25, 7, 34415, 3568, 1103, 7517, 787,
5, 2, 4991, 12401, 36, 7, 148, 111, 939, 6,
11598, 2, 172, 135, 62, 25, 3199, 1602, 3, 928,
1500, 9, 6, 4601, 2, 155, 36, 14, 274, 4,
42945, 9, 4991, 3, 14, 10296, 34, 3568, 8, 51,
148, 30, 2, 58, 16, 11, 1893, 125, 6, 420,
1214, 27, 14542, 940, 11, 7, 29, 951, 18, 17,
15994, 459, 34, 2480, 15211, 3713, 2, 840, 3200, 9,
3568, 13, 107, 9, 175, 94, 25, 51, 10297, 1796,
27, 712, 16, 2, 220, 17, 4, 54, 722, 238,
395, 2, 787, 32, 27, 5236, 3, 32, 27, 7252,
5118, 2461, 6390, 4, 2873, 1495, 15, 2, 1054, 2874,
155, 3, 7015, 7, 409, 9, 41, 220, 17, 41,
390, 3, 3925, 807, 37, 74, 2858, 15, 10297, 115,
31, 189, 3506, 667, 163, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[ 216, 175, 724, 5, 11, 18, 10, 226, 110, 14,
182, 78, 8, 13, 24, 182, 78, 8, 13, 166,
182, 50, 150, 24, 85, 2, 4031, 5935, 107, 96,
28, 1867, 602, 19, 52, 162, 21, 1698, 8, 6,
1181, 367, 2, 351, 10, 140, 419, 4, 333, 5,
6022, 7136, 5055, 1209, 10892, 32, 219, 9, 2, 405,
1413, 13, 4031, 13, 1099, 7, 85, 19, 2, 20,
1018, 4, 85, 565, 34, 24, 807, 55, 5, 68,
658, 10, 507, 8, 4, 668, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[ 10, 121, 24, 28, 98, 74, 589, 9, 149, 2,
7372, 3030, 14543, 1012, 520, 2, 985, 2327, 5, 16847,
5479, 19, 25, 67, 76, 3478, 38, 2, 7372, 3,
25, 67, 76, 2951, 34, 35, 10893, 155, 449, 29495,
23725, 10, 67, 2, 554, 12, 14543, 67, 91, 4,
50, 20, 19, 8, 67, 24, 4228, 2, 2142, 37,
33, 3478, 87, 3, 2564, 160, 155, 11, 634, 126,
24, 158, 72, 286, 13, 373, 2, 4804, 19, 2,
7372, 6794, 6, 30, 128, 73, 48, 10, 886, 8,
13, 24, 4, 85, 20, 19, 8, 13, 35, 218,
3, 428, 710, 2, 107, 936, 7, 54, 72, 223,
3, 10, 96, 122, 2, 103, 54, 72, 82, 2,
658, 202, 2, 106, 293, 103, 7, 1193, 3, 3031,
708, 5760, 3, 2918, 3991, 706, 3327, 349, 148, 286,
13, 139, 6, 2, 1501, 750, 29, 1407, 62, 65,
2612, 71, 40, 14, 4, 547, 9, 62, 8, 7943,
71, 14, 2, 5687, 5, 4868, 3111, 6, 205, 2,
18, 55, 2075, 3, 403, 12, 3111, 231, 45, 5,
271, 3, 68, 1400, 7, 9774, 932, 10, 102, 2,
20, 143, 28, 76, 55, 3810, 9, 2723, 5, 12,
10, 379, 2, 7372, 15, 4, 50, 710, 8, 13,
24, 887, 32, 31, 19, 8, 13, 428],
[18923, 7, 4, 4753, 1669, 12, 3019, 6, 4, 13906,
502, 40, 25, 77, 1588, 9, 115, 6, 21713, 2,
90, 305, 237, 9, 502, 33, 77, 376, 4, 16848,
847, 62, 77, 131, 9, 2, 1580, 338, 5, 18923,
32, 2, 1980, 49, 157, 306, 21713, 46, 981, 6,
10298, 2, 18924, 125, 9, 502, 3, 453, 4, 1852,
630, 407, 3407, 34, 277, 29, 242, 2, 20200, 5,
18923, 77, 95, 41, 1833, 6, 2105, 56, 3, 495,
214, 528, 2, 3479, 2, 112, 7, 181, 1813, 3,
597, 5, 2, 156, 294, 4, 543, 173, 9, 1562,
289, 10038, 5, 2, 20, 26, 841, 1392, 62, 130,
111, 72, 832, 26, 181, 12402, 15, 69, 183, 6,
66, 55, 936, 5, 2, 63, 8, 7, 43, 4,
78, 23726, 15995, 13, 20, 17, 800, 5, 392, 59,
3992, 3, 371, 103, 2596, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]])
tensor([1., 1., 1., 0.])
tensor([165, 86, 218, 145])
tensor([[[ 0.7039, -0.8321, -0.4651],
[-0.3203, 2.2408, 0.5566],
[-0.4643, 0.3046, 0.7046],
[-0.7106, -0.2959, 0.8356]],

    [[-0.4643,  0.3046,  0.7046],
     [ 0.0946, -0.3531,  0.9124],
     [-0.3203,  2.2408,  0.5566],
     [ 0.0000,  0.0000,  0.0000]]], grad_fn=<EmbeddingBackward0>)

initializing model…
setting loss function + optimizer…
start training…
EPOCH: 0
27/09/2022 07:49:33 : train entered…
27/09/2022 07:49:33 : collate_batch entered…
27/09/2022 07:49:35 : collate_batch entered…
27/09/2022 07:49:36 : collate_batch entered…

27/09/2022 08:00:33 : collate_batch entered…
27/09/2022 08:00:34 : collate_batch entered…
27/09/2022 08:00:35 : collate_batch entered…
27/09/2022 08:00:37 : evaluate entered…
evaludate entered…
<torch.utils.data.dataloader.DataLoader object at 0x7f9bd7604f40> 157 <class 'torch.utils.data.dataloader.DataLoader'>
27/09/2022 08:00:37 : collate_batch entered…
27/09/2022 08:00:37 : collate_batch entered…
27/09/2022 08:00:38 : collate_batch entered…

27/09/2022 08:02:31 : collate_batch entered…
27/09/2022 08:02:32 : collate_batch entered…
27/09/2022 08:02:33 : collate_batch entered…
27/09/2022 08:02:33 : collate_batch entered…
total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset): 0.6704 0.6209311981201172
Epoch 0 accuracy: 0.5841 val_accuracy: 0.6704
EPOCH: 1
27/09/2022 08:02:33 : train entered…
27/09/2022 08:02:33 : collate_batch entered…
27/09/2022 08:02:34 : collate_batch entered…


27/09/2022 08:13:32 : collate_batch entered…
27/09/2022 08:13:33 : collate_batch entered…
27/09/2022 08:13:34 : collate_batch entered…
27/09/2022 08:13:35 : collate_batch entered…
27/09/2022 08:13:36 : evaluate entered…
evaludate entered…
<torch.utils.data.dataloader.DataLoader object at 0x7f9bd7604f40> 157 <class 'torch.utils.data.dataloader.DataLoader'>
27/09/2022 08:13:36 : collate_batch entered…
27/09/2022 08:13:37 : collate_batch entered…

27/09/2022 08:15:30 : collate_batch entered…
27/09/2022 08:15:31 : collate_batch entered…
27/09/2022 08:15:32 : collate_batch entered…
27/09/2022 08:15:32 : collate_batch entered…
total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset): 0.6786 0.5940916432380676
Epoch 1 accuracy: 0.7028 val_accuracy: 0.6786
EPOCH: 2
27/09/2022 08:15:32 : train entered…
27/09/2022 08:15:32 : collate_batch entered…
27/09/2022 08:15:33 : collate_batch entered…
27/09/2022 08:15:34 : collate_batch entered…
27/09/2022 08:15:35 : collate_batch entered…
27/09/2022 08:15:36 : collate_batch entered…

27/09/2022 08:26:30 : collate_batch entered…
27/09/2022 08:26:31 : collate_batch entered…
27/09/2022 08:26:32 : collate_batch entered…
27/09/2022 08:26:34 : collate_batch entered…
27/09/2022 08:26:35 : evaluate entered…
evaludate entered…
<torch.utils.data.dataloader.DataLoader object at 0x7f9bd7604f40> 157 <class 'torch.utils.data.dataloader.DataLoader'>
27/09/2022 08:26:35 : collate_batch entered…
27/09/2022 08:26:36 : collate_batch entered…
27/09/2022 08:26:36 : collate_batch entered…
27/09/2022 08:26:37 : collate_batch entered…

27/09/2022 08:28:16 : collate_batch entered…
27/09/2022 08:28:17 : collate_batch entered…
27/09/2022 08:28:18 : collate_batch entered…
27/09/2022 08:28:19 : collate_batch entered…
27/09/2022 08:28:19 : collate_batch entered…
total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset): 0.7462 0.5252116619586945
Epoch 2 accuracy: 0.7499 val_accuracy: 0.7462
EPOCH: 3
27/09/2022 08:28:30 : train entered…
27/09/2022 08:28:30 : collate_batch entered…
27/09/2022 08:28:31 : collate_batch entered…
27/09/2022 08:28:32 : collate_batch entered…
27/09/2022 08:28:33 : collate_batch entered…
27/09/2022 08:28:34 : collate_batch entered…
27/09/2022 09:03:41 : collate_batch entered…
27/09/2022 09:03:42 : collate_batch entered…
27/09/2022 09:03:43 : collate_batch entered…
total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset): 0.82616 0.40658646482467653
test accuracy: 0.8262