I’m working with a very simple TextCNN script; below is a small working example. With MAX_LENGTH set to 64 it runs at normal speed. However, when MAX_LENGTH is changed from 64 to something like 128, training becomes extremely slow on GPU (at least 100x slower, and the slowdown seems to come from the backward pass). Something must be wrong, but I cannot figure out what.
Environment:
- python 3.8
- torch 1.8.0+cu111
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
NUM_SAMPLES = 10_000
VOCAB_SIZE = 5000
NUM_CLASSES = 10
MAX_LENGTH = 64
class TextCNN(nn.Module):
    def __init__(
        self, num_classes, num_embeddings, embedding_dim=300,
        num_filters=256, region_sizes=(2, 3, 4), dropout_rate=0.5,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.conv2ds = nn.ModuleList(
            nn.Conv2d(1, num_filters, kernel_size=(region_size, embedding_dim))
            for region_size in region_sizes
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(num_filters * len(region_sizes), num_classes)

    @staticmethod
    def _conv_and_pool(x, conv2d):
        """x: (batch_size, 1, seq_len, embedding_dim)"""
        x = F.relu(conv2d(x)).squeeze(3)  # (batch_size, num_filters, H_out)
        x = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)
        return x  # (batch_size, num_filters)

    def forward(self, inputs):
        """inputs: (batch_size, seq_len)"""
        x = self.embedding(inputs).unsqueeze(1)
        x = torch.cat(
            [self._conv_and_pool(x, conv2d) for conv2d in self.conv2ds], dim=1,
        )  # (batch_size, num_filters * len(region_sizes))
        x = self.dropout(x)
        x = self.linear(x)
        return x  # (batch_size, num_classes)
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, max_length):
        self.inputs = np.random.randint(VOCAB_SIZE, size=(NUM_SAMPLES, max_length))
        self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {
            'inputs': torch.tensor(self.inputs[idx]),
            'labels': torch.tensor(self.labels[idx]),
        }
train_dataset = MyDataset(max_length=MAX_LENGTH)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128)
device = torch.device('cuda')
model = TextCNN(num_classes=NUM_CLASSES, num_embeddings=VOCAB_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters())
model.train()
s = time.time()
for i, batch_data in enumerate(train_dataloader, start=1):
    print(i)
    labels = batch_data.pop('labels').to(device)
    inputs = batch_data.pop('inputs').to(device)
    logits = model(inputs)
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(logits, labels)
    optimizer.zero_grad()

    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    loss.backward()
    torch.cuda.synchronize()
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    optimizer.step()
    if i == 5:
        break
print(time.time() - s)
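To narrow it down, a rough standalone sketch like the one below (separate from the script above, with shapes chosen to mirror the model: batch size 128, 256 filters, region size 3, embedding dim 300) should show whether the Conv2d forward + backward alone exhibits the same slowdown between the two lengths:

import torch
import torch.nn as nn

device = torch.device('cuda')
# One of the model's convolutions: 256 filters, region size 3, embedding dim 300.
conv = nn.Conv2d(1, 256, kernel_size=(3, 300)).to(device)

for seq_len in (64, 128):
    # Fake embedded batch: (batch_size, 1, seq_len, embedding_dim).
    x = torch.randn(128, 1, seq_len, 300, device=device, requires_grad=True)
    # Warm-up iterations so cuDNN algorithm selection is not part of the measurement.
    for _ in range(3):
        conv(x).sum().backward()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(10):
        conv(x).sum().backward()
    end.record()
    torch.cuda.synchronize()
    print(seq_len, start.elapsed_time(end) / 10, 'ms per forward + backward')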