Backward becomes extremely slow

I’m training a very simple TextCNN model; a minimal working example is below. With MAX_LENGTH set to 64 it runs normally, but when MAX_LENGTH is increased to something like 128, training becomes extremely slow on the GPU (at least 100x slower, and the time seems to be spent in backward). Something must be wrong, but I cannot figure out what.

Environment:

  • Python 3.8
  • torch 1.8.0+cu111
import time

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F


NUM_SAMPLES = 10_000
VOCAB_SIZE = 5000
NUM_CLASSES = 10

MAX_LENGTH = 64


class TextCNN(nn.Module):
    def __init__(
        self, num_classes, num_embeddings, embedding_dim=300,
        num_filters=256, region_sizes=(2, 3, 4), dropout_rate=0.5,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.conv2ds = nn.ModuleList(
            nn.Conv2d(1, num_filters, kernel_size=(region_size, embedding_dim))
            for region_size in region_sizes
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(num_filters * len(region_sizes), num_classes)

    @staticmethod
    def _conv_and_pool(x, conv2d):
        """(batch_size, 1, seq_len, embedding_dim)"""
        x = F.relu(conv2d(x)).squeeze(3)  # (batch_size, num_filters, H_out)
        x = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)
        return x  # (batch_size, num_filters)

    def forward(self, inputs):
        """(batch_size, seq_len)"""
        x = self.embedding(inputs).unsqueeze(1)
        x = torch.cat(
            [self._conv_and_pool(x, conv2d) for conv2d in self.conv2ds], dim=1,
        )  # (batch_size, num_filters * len(region_sizes))
        x = self.dropout(x)
        x = self.linear(x)
        return x  # (batch_size, num_classes)


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, max_length):
        self.inputs = np.random.randint(VOCAB_SIZE, size=(NUM_SAMPLES, max_length))
        self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {
            'inputs': torch.tensor(self.inputs[idx]),
            'labels': torch.tensor(self.labels[idx]),
        }


train_dataset = MyDataset(max_length=MAX_LENGTH)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128)

device = torch.device('cuda')
model = TextCNN(num_classes=NUM_CLASSES, num_embeddings=VOCAB_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters())

model.train()

s = time.time()
for i, batch_data in enumerate(train_dataloader, start=1):
    print(i)
    labels = batch_data.pop('labels').to(device)
    inputs = batch_data.pop('inputs').to(device)
    logits = model(inputs)
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(logits, labels)
    optimizer.zero_grad()

    torch.cuda.synchronize()  # make sure the forward pass has finished
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    loss.backward()

    torch.cuda.synchronize()  # wait for backward to finish before recording the end event
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))  # milliseconds

    optimizer.step()
    if i == 5:
        break
print(time.time() - s)

I cannot reproduce the issue and get these results:

# MAX_LENGTH = 128
0.044320106506347656
# MAX_LENGTH = 64
0.035507917404174805

These numbers are from a 3090 using the current nightly binaries with CUDA 11.7, after adding the missing device synchronizations that are needed due to the asynchronous execution on the GPU.
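For reference, the pattern I mean looks roughly like this (a minimal sketch around the loss.backward() call from your script; elapsed_time reports milliseconds, and the final synchronize is what makes the number meaningful given the asynchronous execution):

torch.cuda.synchronize()  # flush any pending GPU work before starting the timer
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
loss.backward()
end.record()

torch.cuda.synchronize()  # wait for backward to actually finish
print(start.elapsed_time(end))  # elapsed time in milliseconds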

Thanks for trying. I added the torch.cuda.synchronize() calls (I’m not sure whether I’m using them correctly). Below is my output on a 3080 Ti.

MAX_LENGTH = 128
1
1063.4217529296875
2
1024.489501953125
3
1024.0513916015625
4
1027.6160888671875
5
1031.493896484375
6.091180801391602
MAX_LENGTH = 64
1
13.087743759155273
2
4.733119964599609
3
4.407296180725098
4
4.50761604309082
5
4.457215785980225
0.9030911922454834

You might want to update to the latest release in this case, as newer releases might have already fixed this kind of performance regression.
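If you want to double-check which build you are running before and after the upgrade, something like this works (the exact upgrade command depends on how you installed PyTorch, e.g. pip install --upgrade torch for a plain pip setup, so treat that part as a sketch):

import torch

print(torch.__version__)               # e.g. '1.8.0+cu111'
print(torch.version.cuda)              # CUDA version the binaries were built with
print(torch.backends.cudnn.version())  # cuDNN version shipped with the build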

Thanks, upgrading torch fixed it :joy: