Vectorization of Masking in Nested for Loop

Hi,

I have a nested for loop that does some masking and summation as below:

for i in range(200):
    for j in range(200):
        xlim = bev_limit - i * bev_resolution
        ylim = bev_limit - j * bev_resolution

        pcm_mask = (transformed_coordinates[0] <= xlim) & \
            (transformed_coordinates[0] >= xlim - bev_resolution) & \
            (transformed_coordinates[1] < ylim) & \
            (transformed_coordinates[1] >= ylim - bev_resolution)
            
        new_pcm[:, i, j] = (0.01 * torch.sum(kron_pcm[:, pcm_mask],
                                                 dim=1)).clip(max=1.0)
        fractions_pcm[i, j] = (0.01 * torch.sum(pcm_mask)).clip(max=1.0)

In the code above, transformed_coordinates is a 3x2000x2000 tensor (its first two slices along dimension 0 hold the x and y coordinates, like a meshgrid). kron_pcm is also a 3x2000x2000 tensor, and new_pcm is a 3x200x200 tensor.

With all tensors loaded in GPU memory, doing the loop takes about 16.6 seconds, which itself is a big improvement over using NumPy, which takes ~606 seconds. Is there a way to further vectorize this to reduce computation time?

Hi Goodarz!

Yes, at the cost of materializing a large mask tensor.

The idea is to perform the masking by multiplying with zeros and ones, rather
than selecting a subset of elements by “indexing” with a boolean tensor, and
to use einsum() to perform the multiplications and the summation.
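As a rough sketch of that idea, assuming unit-width bins and made-up sizes (NumPy here so it runs without a GPU; the torch version is analogous):

```python
import numpy as np

rng = np.random.default_rng(0)
nL, nB = 4, 16                             # 4 bins per axis, 16x16 grid
coords = nL * rng.random((2, nB, nB))      # x / y coordinates in [0, nL)
vals = rng.standard_normal((3, nB, nB))    # per-channel values

# loop version: boolean-index each output cell
out_loop = np.zeros((3, nL, nL))
for i in range(nL):
    for j in range(nL):
        m = (coords[0] >= i) & (coords[0] < i + 1) & \
            (coords[1] >= j) & (coords[1] < j + 1)
        out_loop[:, i, j] = vals[:, m].sum(axis=1)

# loop-free version: one 0 / 1 mask per bin, combined with einsum
bins = np.arange(nL)[:, None, None]
ma = ((coords[0][None] >= bins) & (coords[0][None] < bins + 1)).astype(float)
mb = ((coords[1][None] >= bins) & (coords[1][None] < bins + 1)).astype(float)
out_einsum = np.einsum('cij, mij, nij -> cmn', vals, ma, mb)

print(np.allclose(out_loop, out_einsum))   # True
```

Row m of ma is the 0 / 1 indicator of x-bin m (and likewise mb for the y-bins), so the product ma[m] * mb[n] reproduces each cell's boolean mask and the einsum() sums over all cells at once.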

Your example problem is fairly large, and this masking scheme won’t fit into my
4 GB GPU. I have an example script that performs a no-loop mask / einsum()
computation on a smaller version of your problem. It also shows how to apply
the scheme to your full problem by slicing it into chunks that do fit
into the GPU.

Here is the script:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_properties (0).total_memory)

_ = torch.manual_seed (2023)

nc = 3
nBig = 2000
nLit = 200

smallFac = 2
nSlice = 8

for  doFull in (False, True):
    print ('doFull:', doFull)
    if  doFull:  print ('perform full-sized computation')
    else:        print ('perform reduced-sized computation')
    
    nB = nBig
    nL = nLit
    if  not doFull:         # reduce size of problem
        nB //= smallFac
        nL //= smallFac
    
    tc = nLit * torch.rand (nc, nB, nB, device = 'cuda')   # some sample data
    kron_pcm = torch.randn (nc, nB, nB, device = 'cuda')   # some sample data
    new_pcm = torch.zeros (nc, nL, nL, device = 'cuda')    # could use torch.empty()
    
    print ('tc.shape        = ', tc.shape)
    print ('kron_pcm.shape  = ', kron_pcm.shape)
    print ('new_pcm.shape   = ', new_pcm.shape)
    
    for  i in range (nL):   # version with loops
        for  j in range (nL):
            pcm_mask = (tc[0] > i)  &  (tc[0] < i + 1)  &  (tc[1] > j)  &  (tc[1] < j + 1)
            new_pcm[:, i, j] = torch.sum (kron_pcm[:, pcm_mask], dim = 1)
    
    if  not doFull:         # loop-free version -- reduced-size problem fits in 4 GB
        print ('perform full masking and single einsum() ...')
        ma = torch.logical_and (tc[0, None] > torch.arange (nL, device = 'cuda')[:, None, None],  tc[0, None] < (torch.arange (nL, device = 'cuda') + 1)[:, None, None]).float()
        mb = torch.logical_and (tc[1, None] > torch.arange (nL, device = 'cuda')[:, None, None],  tc[1, None] < (torch.arange (nL, device = 'cuda') + 1)[:, None, None]).float()
        new_pcmB = torch.einsum ('cij, mij, nij -> cmn', kron_pcm, ma, mb)
        
    else:                   # full-size problem doesn't fit in 4 GB -- loop over slices
        print ('perform masking and einsum() in slices, nSlice =', nSlice)
        for  slice in range (nSlice):
            sa = slice * (nB // nSlice)
            sb = (slice + 1) * (nB // nSlice)
            print ('compute masks for slice [%d:%d] ...' % (sa, sb))
            ma = torch.logical_and (tc[0, None, sa:sb, :] > torch.arange (nL, device = 'cuda')[:, None, None],  tc[0, None, sa:sb, :] < (torch.arange (nL, device = 'cuda') + 1)[:, None, None]).float()
            mb = torch.logical_and (tc[1, None, sa:sb, :] > torch.arange (nL, device = 'cuda')[:, None, None],  tc[1, None, sa:sb, :] < (torch.arange (nL, device = 'cuda') + 1)[:, None, None]).float()
            print ('perform einsum() for slice [%d:%d] ...' % (sa, sb))
            new_pcmB_slice = torch.einsum ('cij, mij, nij -> cmn', kron_pcm[:, sa:sb], ma, mb)
            if  slice == 0:  new_pcmB  = new_pcmB_slice
            else:            new_pcmB += new_pcmB_slice
    
    print ('new_pcmB.shape  = ', new_pcmB.shape)
    print ('torch.allclose (new_pcmB, new_pcm, atol = 1.e-4) =', torch.allclose (new_pcmB, new_pcm, atol = 1.e-4))

And here is its output:

2.1.0
11.8
4236312576
doFull: False
perform reduced-sized computation
tc.shape        =  torch.Size([3, 1000, 1000])
kron_pcm.shape  =  torch.Size([3, 1000, 1000])
new_pcm.shape   =  torch.Size([3, 100, 100])
perform full masking and single einsum() ...
new_pcmB.shape  =  torch.Size([3, 100, 100])
torch.allclose (new_pcmB, new_pcm, atol = 1.e-4) = True
doFull: True
perform full-sized computation
tc.shape        =  torch.Size([3, 2000, 2000])
kron_pcm.shape  =  torch.Size([3, 2000, 2000])
new_pcm.shape   =  torch.Size([3, 200, 200])
perform masking and einsum() in slices, nSlice = 8
compute masks for slice [0:250] ...
perform einsum() for slice [0:250] ...
compute masks for slice [250:500] ...
perform einsum() for slice [250:500] ...
compute masks for slice [500:750] ...
perform einsum() for slice [500:750] ...
compute masks for slice [750:1000] ...
perform einsum() for slice [750:1000] ...
compute masks for slice [1000:1250] ...
perform einsum() for slice [1000:1250] ...
compute masks for slice [1250:1500] ...
perform einsum() for slice [1250:1500] ...
compute masks for slice [1500:1750] ...
perform einsum() for slice [1500:1750] ...
compute masks for slice [1750:2000] ...
perform einsum() for slice [1750:2000] ...
new_pcmB.shape  =  torch.Size([3, 200, 200])
torch.allclose (new_pcmB, new_pcm, atol = 1.e-4) = True

Best.

K. Frank


Thanks a lot, your solution worked flawlessly! The only change I made was to use half precision (float16) instead of float32 to cut down on VRAM, since my data is in [0, 1]. On a 3090, the full-size (2000x2000) example finishes in ~80 ms and uses ~8 GB.
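For reference, a small NumPy sketch of that half-precision trade-off (illustrative sizes; torch's .half() behaves analogously on the GPU): float16 masks take half the memory but carry less precision, which is acceptable when the data lie in [0, 1] and the sums stay moderate.

```python
import numpy as np

rng = np.random.default_rng(0)
vals = rng.random((3, 64, 64))           # data in [0, 1]
mask = rng.random((8, 64, 64)) < 0.1     # sparse boolean masks

# float16 masks use half the memory of float32 masks
m32 = mask.astype(np.float32)
m16 = mask.astype(np.float16)
print(m16.nbytes * 2 == m32.nbytes)      # True

# the reduced precision shows up only as a small error in the sums
out32 = np.einsum('cij, mij -> cm', vals.astype(np.float32), m32)
out16 = np.einsum('cij, mij -> cm', vals.astype(np.float16), m16)
print(np.max(np.abs(out16.astype(np.float32) - out32)))
```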