Cdist vs matmul

Hi,

I am trying to build a video retrieval system using cosine similarity. L2 distance could also be used, since for normalized vectors a and b the squared distance can be written as || a - b ||^2 = 2 - 2 * <a, b>.
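As a quick sanity check of that identity, here is a minimal sketch on the CPU. The sizes (10 database rows, d = 16) and the variable names are made up for illustration only; it compares the squared L2 distance with 2 - 2 * <a, b> and checks that the nearest neighbour by distance is the same row as the best match by inner product.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16  # illustrative feature dimension, not the one from the post

# One normalized query vector and a small normalized "database".
query = F.normalize(torch.randn(d), dim=0, p=2)
database = F.normalize(torch.randn(10, d), dim=1, p=2)

# Squared L2 distance and inner product between the query and every row.
sq_dist = (database - query).pow(2).sum(dim=1)  # shape (10,)
inner_product = torch.mv(database, query)       # shape (10,)

# For unit vectors: ||a - b||^2 = 2 - 2 * <a, b>
identity_holds = torch.allclose(sq_dist, 2 - 2 * inner_product, atol=1e-5)

# Smallest distance and largest inner product should pick the same row.
same_top1 = sq_dist.argmin().item() == inner_product.argmax().item()

print(identity_holds, same_top1)
```

So for retrieval on normalized features, ranking by inner product and ranking by L2 distance should give the same order, up to floating-point ties.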

Now I have two matrices, A: [N x d] and B: [M x d].

L2 distance can be calculated in PyTorch as torch.cdist(A, B), and cosine similarity as the inner product torch.mm(A, B.transpose(0, 1)). However, I found the latter to be much slower than the former. Any idea why?

Below is the code I used to do the comparison.

import time
import torch
import torch.nn.functional as F
import numpy as np

def compare_l2dist_inner_product_time(n_videos=2000, d=256, n_query=1000, n_runs=5):
    st_time = time.time()
    fake_database = F.normalize(torch.randn((n_videos, d), dtype=torch.float32).cuda(), dim=1, p=2)
    fake_query = F.normalize(torch.randn((n_query, d), dtype=torch.float32).cuda(), dim=1, p=2)
    print("Construct fake database + query time {}".format(time.time() - st_time))
    print("fake_database shape {} fake_query shape {}".format(fake_database.shape, fake_query.shape))

    times_l2dist = []
    for _ in range(n_runs):
        st_time = time.time()
        l2_dist = torch.cdist(fake_query, fake_database, p=2)  # (n_query, n_videos)
        times_l2dist.append(time.time() - st_time)
    avg_time_l2dist = np.mean(times_l2dist)
    print("L2 Distance time {}".format(avg_time_l2dist))

    times_ip = []
    fake_database = fake_database.transpose(0, 1)
    for _ in range(n_runs):
        st_time = time.time()
        inner_product = torch.mm(fake_query, fake_database)  # (n_query, n_videos)
        times_ip.append(time.time() - st_time)
    avg_time_ip = np.mean(times_ip)
    print("Inner Product time {}".format(avg_time_ip))

compare_l2dist_inner_product_time()

Output:

Construct fake database + query time 7.20833158493042
fake_database shape torch.Size([2000, 256]) fake_query shape torch.Size([1000, 256])
L2 Distance time 5.9604644775390625e-05
Inner Product time 0.07725939750671387

OK, I figured it out: it seems PyTorch needs some time to allocate CUDA memory.

Not only this, but CUDA operations are executed asynchronously, so you should synchronize before starting and stopping the timer using torch.cuda.synchronize().
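A minimal sketch of what that could look like as a reusable helper. The name time_op and its parameters are hypothetical, not a PyTorch API. It calls torch.cuda.synchronize() right before starting and right before stopping the timer, runs a few untimed calls first, and falls back to plain wall-clock timing when no GPU is available.

```python
import time
import torch

def time_op(fn, n_warmup=3, n_runs=5):
    """Average wall-clock seconds per call of fn (hypothetical helper)."""
    use_cuda = torch.cuda.is_available()
    # Untimed calls, so one-off setup costs are not measured.
    for _ in range(n_warmup):
        fn()
    times = []
    for _ in range(n_runs):
        if use_cuda:
            torch.cuda.synchronize()  # wait for pending kernels before starting
        start = time.perf_counter()
        fn()
        if use_cuda:
            torch.cuda.synchronize()  # wait for fn's kernels before stopping
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1000, 256, device=device)
y = torch.randn(2000, 256, device=device)

print("cdist:", time_op(lambda: torch.cdist(x, y, p=2)))
print("mm:   ", time_op(lambda: torch.mm(x, y.transpose(0, 1))))
```

Without the synchronize calls, the timer would mostly measure how long it takes to launch the kernel, not how long the kernel runs.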

Thanks! I added torch.cuda.synchronize() as @ptrblck suggested, and also added some warmup runs so that memory allocation time is not counted.

def compare_l2dist_inner_product_time(n_videos=2000, d=256, n_query=1000, n_runs=10, n_warmup_runs=10):
    torch.cuda.synchronize()
    st_time = time.time()
    fake_database = F.normalize(torch.randn((n_videos, d), dtype=torch.float32).cuda(), dim=1, p=2)
    fake_query = F.normalize(torch.randn((n_query, d), dtype=torch.float32).cuda(), dim=1, p=2)
    torch.cuda.synchronize()
    print("Construct fake database + query time {}".format(time.time() - st_time))
    print("fake_database shape {} fake_query shape {}".format(fake_database.shape, fake_query.shape))

    times_l2dist = []
    for _ in range(n_warmup_runs + n_runs):
        torch.cuda.synchronize()
        st_time = time.time()
        l2_dist = torch.cdist(fake_query, fake_database, p=2)  # (n_query, n_videos)
        torch.cuda.synchronize()
        times_l2dist.append(time.time() - st_time)
    avg_time_l2dist = np.mean(times_l2dist[n_warmup_runs:])
    print("L2 Distance time {}".format(avg_time_l2dist))

    times_ip = []
    fake_database = fake_database.transpose(0, 1)
    for _ in range(n_warmup_runs + n_runs):
        torch.cuda.synchronize()
        st_time = time.time()
        inner_product = torch.mm(fake_query, fake_database)  # (n_query, n_videos)
        torch.cuda.synchronize()
        times_ip.append(time.time() - st_time)
    avg_time_ip = np.mean(times_ip[n_warmup_runs:])
    print("Inner Product time {}".format(avg_time_ip))

Here are the outputs (on an RTX 2080Ti):

Construct fake database + query time 0.008526802062988281
fake_database shape torch.Size([2000, 256]) fake_query shape torch.Size([1000, 256])
L2 Distance time 0.014820019404093424
Inner Product time 0.00018693606058756512

The results look more reasonable now, though I did not expect Inner Product to be so fast compared to torch.cdist.
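Given that gap, one option is to get the L2 distances from the inner product directly. Below is a rough sketch, assuming both inputs are already L2-normalized along dim=1. The function name l2_dist_from_inner_product and the small sizes are mine, just for illustration, and it is checked against torch.cdist on the CPU.

```python
import torch
import torch.nn.functional as F

def l2_dist_from_inner_product(query, database):
    """Pairwise L2 distances for row-normalized inputs via one matrix multiply.

    Hypothetical helper; only valid when every row has unit L2 norm.
    """
    inner_product = torch.mm(query, database.transpose(0, 1))  # (n_query, n_videos)
    # Rounding can push 2 - 2 * <a, b> slightly below zero, so clamp before sqrt.
    return torch.sqrt(torch.clamp(2 - 2 * inner_product, min=0))

torch.manual_seed(0)
fake_query = F.normalize(torch.randn(50, 256), dim=1, p=2)
fake_database = F.normalize(torch.randn(80, 256), dim=1, p=2)

l2_from_mm = l2_dist_from_inner_product(fake_query, fake_database)
l2_from_cdist = torch.cdist(fake_query, fake_database, p=2)

print(torch.allclose(l2_from_mm, l2_from_cdist, atol=1e-4))
```

If only the ranking is needed, the sqrt and clamp can be skipped and the inner product used as the score.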

Thanks for the code! I wouldn’t expect such a difference, so I just reran the code and got:

Construct fake database + query time 0.018206119537353516
fake_database shape torch.Size([2000, 256]) fake_query shape torch.Size([1000, 256])
L2 Distance time 0.00036679506301879884
Inner Product time 0.0001362919807434082

on a TitanV.

That’s strange. I tried on a Titan X (Pascal) and a 1080Ti and got results similar to my 2080Ti. I am using Python 3.7 and PyTorch 1.3.1 with CUDA 10. Which versions are you using?
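For comparing setups, something like the following prints the relevant versions in one go. describe_environment is just an illustrative name; the attributes it reads (torch.__version__, torch.version.cuda, torch.cuda.get_device_name) are standard PyTorch.

```python
import sys
import torch

def describe_environment():
    """Collect version info worth posting alongside benchmark numbers."""
    info = {
        "python": sys.version.split()[0],
        "pytorch": str(torch.__version__),
        "cuda": torch.version.cuda,  # None on CPU-only builds
        "gpu": None,
    }
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
    return info

for key, value in describe_environment().items():
    print("{}: {}".format(key, value))
```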

I’m using PyTorch 1.4.0.dev20191109 with Python 3.7 and CUDA 10.1.

After switching to PyTorch 1.4.0.dev20191109 with Python 3.7 and CUDA 10.1, I get similar performance. It seems to be a PyTorch 1.3.1 or CUDA 10.0 problem.
