Simple Embedding Model: Training too slow

Model:

import torch
import torch.nn as nn

class RecoEmbeddings(nn.Module):

    def __init__(self, n_customers, n_articles, embedding_dim, use_sparse=True):
        super().__init__()

        # latent vectors and per-entity biases; sparse=True makes backward() emit sparse gradients
        self.embed_customer = nn.Embedding(n_customers, embedding_dim, sparse=use_sparse)
        self.bias_customer = nn.Embedding(n_customers, 1, sparse=use_sparse)
        self.embed_article = nn.Embedding(n_articles, embedding_dim, sparse=use_sparse)
        self.bias_article = nn.Embedding(n_articles, 1, sparse=use_sparse)

        self.activation = nn.Sigmoid()

    def forward(self, x):
        # x has shape [batch, 2]: column 0 = customer index, column 1 = article index
        customers = self.embed_customer(x[:, 0])
        articles = self.embed_article(x[:, 1])

        # dot product of the two embeddings plus both bias terms -> shape [batch, 1]
        res = (customers * articles).sum(dim=1, keepdim=True)
        res = res + self.bias_customer(x[:, 0]) + self.bias_article(x[:, 1])

        return self.activation(res)
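For reference, the forward pass above scores a (customer c, article a) pair as shown below, where u_c and v_a are the rows of embed_customer and embed_article, b_c and b_a are the matching rows of bias_customer and bias_article, d is embedding_dim, and σ is the sigmoid:

```latex
\hat{y}_{c,a} = \sigma\!\left(\mathbf{u}_c^{\top}\mathbf{v}_a + b_c + b_a\right),
\qquad \mathbf{u}_c,\ \mathbf{v}_a \in \mathbb{R}^{d},\quad b_c,\ b_a \in \mathbb{R}
```

Each example reads one row from each of the four tables and does about d multiply-adds, so the arithmetic per batch is tiny no matter how large n_customers is.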

Model initialization:

from torchinfo import summary

model = RecoEmbeddings(n_customers, n_articles, embedding_dim=8, use_sparse=True)
!nvidia-smi  # Jupyter shell escape: print GPU status
summary(model, input_data=torch.ones((1, 2), dtype=torch.int32), col_names=["input_size", "output_size", "num_params", "mult_adds"])
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W /  70W |    608MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
============================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Mult-Adds
============================================================================================================================================
RecoEmbeddings                           [1, 2]                    [1, 1]                    --                        --
├─Embedding: 1-1                         [1]                       [1, 8]                    13,709,560                13,709,560
├─Embedding: 1-2                         [1]                       [1, 8]                    101,544                   101,544
├─Embedding: 1-3                         [1]                       [1, 1]                    1,713,695                 1,713,695
├─Embedding: 1-4                         [1]                       [1, 1]                    12,693                    12,693
├─Sigmoid: 1-5                           [1, 1]                    [1, 1]                    --                        --
============================================================================================================================================
Total params: 15,537,492
Trainable params: 15,537,492
Non-trainable params: 0
Total mult-adds (M): 15.54
============================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 62.15
Estimated Total Size (MB): 62.15
============================================================================================================================================
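As a sanity check on the summary, the table sizes can be recovered from the Param # column (each embedding row is 8 wide, each bias row is 1 wide) and the totals recomputed. This sketch only uses the numbers printed above and assumes float32 parameters (4 bytes each):

```python
# Recover the table sizes from the torchinfo summary above and check the totals.
embedding_dim = 8

n_customers = 13_709_560 // embedding_dim   # embed_customer is n_customers x 8
n_articles = 101_544 // embedding_dim       # embed_article is n_articles x 8

# every customer and article has an embedding_dim-wide vector plus one bias scalar
total_params = (n_customers + n_articles) * (embedding_dim + 1)
params_mb = total_params * 4 / 1e6          # float32 = 4 bytes per parameter

print(n_customers, n_articles)   # 1713695 12693
print(total_params)              # 15537492
print(round(params_mb, 2))       # 62.15
```

Almost all parameters sit in the customer table, yet a batch of 1024 examples reads at most 1024 of its 1,713,695 rows.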

Training Loop:

from tqdm import tqdm

sample = 100000
batch_size = 1024

# optimizer, criterion, PortaDataSet and df_votes are defined elsewhere in the notebook
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in tqdm(range(24), desc="Training: "):
    # the DataLoader is rebuilt at the start of every epoch
    dataloader = PortaDataSet.get_data_loader(df_votes, batch_size=batch_size, sample=sample)

    for batch_idx, batch in tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}: ", total=len(dataloader), position=0, leave=True):
        # get the inputs; each batch is [inputs, targets]
        X, y = batch
        X = X.to(device)
        y = y.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_hat = model(X)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

    print("Loss:", loss.item())
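One common way to cut per-batch overhead for a model this small is to move all (customer, article) index pairs and labels to the GPU once and slice shuffled batches there, instead of going through a DataLoader that is rebuilt every epoch. The sketch below does that and pairs the sparse embeddings with torch.optim.SparseAdam, which accepts sparse gradients. It is a hypothetical alternative, not the original PortaDataSet pipeline: nn.BCELoss assumes 0/1 float targets, and TinyReco plus the random data are stand-ins for illustration only.

```python
import torch
import torch.nn as nn


def train_on_device(model, X, y, epochs=1, batch_size=1024, lr=1e-2):
    """Sketch: keep every (index, label) pair on the model's device and slice
    shuffled mini-batches directly, with no DataLoader in the loop."""
    device = next(model.parameters()).device
    X, y = X.to(device), y.to(device)              # one host-to-device copy, up front
    # SparseAdam accepts the sparse gradients produced by nn.Embedding(sparse=True)
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=lr)
    criterion = nn.BCELoss()                        # assumption: sigmoid outputs, 0/1 targets
    for _ in range(epochs):
        perm = torch.randperm(X.shape[0], device=device)
        for start in range(0, X.shape[0], batch_size):
            idx = perm[start:start + batch_size]
            optimizer.zero_grad()
            loss = criterion(model(X[idx]), y[idx])
            loss.backward()
            optimizer.step()
    return loss.item()                              # loss of the last batch


class TinyReco(nn.Module):
    """Hypothetical stand-in with the same dot-product-plus-bias structure."""
    def __init__(self, n_users, n_items, dim):
        super().__init__()
        self.u = nn.Embedding(n_users, dim, sparse=True)
        self.v = nn.Embedding(n_items, dim, sparse=True)
        self.bu = nn.Embedding(n_users, 1, sparse=True)
        self.bv = nn.Embedding(n_items, 1, sparse=True)

    def forward(self, x):
        s = (self.u(x[:, 0]) * self.v(x[:, 1])).sum(dim=1, keepdim=True)
        return torch.sigmoid(s + self.bu(x[:, 0]) + self.bv(x[:, 1]))


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_users, n_items, n_rows = 1000, 200, 10_000
    # hypothetical synthetic data: random (customer, article) pairs with 0/1 labels
    X = torch.stack([torch.randint(0, n_users, (n_rows,)),
                     torch.randint(0, n_items, (n_rows,))], dim=1)
    y = torch.randint(0, 2, (n_rows, 1)).float()
    model = TinyReco(n_users, n_items, dim=8).to(device)
    print("last batch loss:", train_on_device(model, X, y, epochs=2))
```

This only works if all parameters receive sparse gradients (SparseAdam rejects dense ones) and the training pairs fit in GPU memory, which 100,000 integer pairs easily do.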

The training is too slow: it takes around 2 minutes per iteration, and there are ~1500 iterations per epoch. Is this expected? The model is relatively small (about 15.5M parameters), so I was not expecting it to be this slow.
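A first step in locating the bottleneck is to split each iteration's wall-clock time into time spent waiting for the next batch and time spent on transfer plus forward/backward/step. The helper below is a hypothetical sketch that works with any iterable loader yielding (inputs, targets); it calls torch.cuda.synchronize() because CUDA kernels run asynchronously, so the clock would otherwise stop before the GPU work finishes. The Linear model and random data in the demo are stand-ins.

```python
import time

import torch


def profile_loop(model, dataloader, optimizer, criterion, device, max_batches=20):
    """Return average seconds per batch as (waiting for data, transfer + compute).
    `device` must be a torch.device."""
    data_time, compute_time = 0.0, 0.0
    t0 = time.perf_counter()
    for i, (X, y) in enumerate(dataloader):
        t1 = time.perf_counter()
        data_time += t1 - t0                    # time the loader took to produce this batch
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        if device.type == "cuda":
            torch.cuda.synchronize()            # wait for queued GPU kernels before timing
        t0 = time.perf_counter()
        compute_time += t0 - t1
        if i + 1 >= max_batches:
            break
    n = i + 1
    return data_time / n, compute_time / n


if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(2, 1).to(device)    # hypothetical stand-in model
    loader = DataLoader(TensorDataset(torch.randn(4096, 2), torch.randn(4096, 1)),
                        batch_size=1024, shuffle=True)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    d, c = profile_loop(model, loader, opt, torch.nn.MSELoss(), device)
    print(f"data wait: {d * 1e3:.2f} ms/batch, compute: {c * 1e3:.2f} ms/batch")
```

If the data-wait figure dominates, the dataset and loader are the thing to optimize rather than the model.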

What I have tried so far to improve performance:

  • Reduced the embedding dimension
  • Set sparse=True in the Embedding layers, which gave an almost 20% speedup
  • Experimented with different batch sizes and optimizers
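The speedup from sparse=True comes from the size of the weight gradient: a dense gradient is a full num_embeddings x dim matrix, while a sparse one only holds the rows that were looked up in the batch. This sketch counts the stored gradient values for one backward pass; the table size of 100,000 rows is an illustrative choice, smaller than the real customer table.

```python
import torch
import torch.nn as nn


def grad_storage(sparse, num_embeddings=100_000, dim=8, batch=1024, seed=0):
    """Number of gradient values one backward pass materializes for an Embedding."""
    torch.manual_seed(seed)
    emb = nn.Embedding(num_embeddings, dim, sparse=sparse)
    idx = torch.randint(0, num_embeddings, (batch,))
    emb(idx).sum().backward()
    g = emb.weight.grad
    if g.is_sparse:
        return g.coalesce().values().numel()   # only the distinct rows that were looked up
    return g.numel()                            # the full num_embeddings x dim matrix


print("dense :", grad_storage(sparse=False))   # 800000
print("sparse:", grad_storage(sparse=True))    # at most 1024 * 8
```

Sparse gradients also constrain the optimizer choice: torch.optim.SparseAdam is built for them, whereas torch.optim.Adam raises an error when it receives one.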