I have a nested for loop that does some masking and summation as below:
for i in range(200):
    for j in range(200):
        xlim = bev_limit - i * bev_resolution
        ylim = bev_limit - j * bev_resolution
        pcm_mask = (transformed_coordinates[0] <= xlim) & \
                   (transformed_coordinates[0] >= xlim - bev_resolution) & \
                   (transformed_coordinates[1] < ylim) & \
                   (transformed_coordinates[1] >= ylim - bev_resolution)
        new_pcm[:, i, j] = (0.01 * torch.sum(kron_pcm[:, pcm_mask],
                                             dim=1)).clip(max=1.0)
        fractions_pcm[i, j] = (0.01 * torch.sum(pcm_mask)).clip(max=1.0)
In the code above, transformed_coordinates is a 3x2000x2000 tensor whose first two entries along the leading dimension hold the x and y coordinate grids (like a meshgrid). kron_pcm is also a 3x2000x2000 tensor, and new_pcm is a 3x200x200 tensor.
With all tensors in GPU memory, the loop takes about 16.6 seconds, which is already a big improvement over the NumPy version (~606 seconds). Is there a way to vectorize this further to reduce the computation time?
Yes, at the cost of materializing a large mask tensor.
The idea is to perform the masking by multiplying with zeros and ones, rather
than selecting a subset of elements by “indexing” with a boolean tensor, and
use einsum() to perform the multiplications and summation.
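As a minimal sketch of that equivalence on toy data (the sizes and names here are my own, not taken from your problem):

import torch

# Toy data: 3 "channels" of 5x5 values and one boolean mask over the 5x5 grid.
x = torch.randn(3, 5, 5)
mask = torch.rand(5, 5) > 0.5

# Boolean indexing selects the masked elements and sums them per channel ...
ref = torch.sum(x[:, mask], dim=1)

# ... which equals multiplying by a 0/1 float mask and summing over the grid,
# written here as an einsum() over the spatial indices.
alt = torch.einsum('cij, ij -> c', x, mask.float())

print(torch.allclose(ref, alt))   # True (up to floating-point rounding)

The three-operand einsum() in the script below is the same idea, with one 0/1 mask per output row index and one per output column index.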
Your example problem is fairly large, and this masking scheme won't fit into my
4 GB GPU. Below is an example script that performs a loop-free mask / einsum()
computation on a smaller version of your problem. It also shows how to apply
the scheme to your full-size problem by slicing it into chunks that do fit on
the GPU.
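For a sense of scale, here is a rough back-of-the-envelope estimate (my own arithmetic, assuming float32 masks) of why the two full-size masks alone exceed 4 GB:

# Each full-size mask ma / mb has shape (nL, nB, nB) = (200, 2000, 2000).
nL, nB = 200, 2000
gb_per_mask = nL * nB * nB * 4 / 1e9   # 4 bytes per float32 element
print(gb_per_mask)       # 3.2 GB for one mask
print(2 * gb_per_mask)   # 6.4 GB for ma and mb together, before counting the data tensors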
Here is the script:
import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_properties (0).total_memory)

_ = torch.manual_seed (2023)

nc = 3
nBig = 2000
nLit = 200
smallFac = 2
nSlice = 8

for doFull in (False, True):
    print ('doFull:', doFull)
    if doFull: print ('perform full-sized computation')
    else:      print ('perform reduced-sized computation')
    nB = nBig
    nL = nLit
    if not doFull:   # reduce size of problem
        nB //= smallFac
        nL //= smallFac
    tc = nLit * torch.rand (nc, nB, nB, device = 'cuda')     # some sample data
    kron_pcm = torch.randn (nc, nB, nB, device = 'cuda')     # some sample data
    new_pcm = torch.zeros (nc, nL, nL, device = 'cuda')      # could use torch.empty()
    print ('tc.shape = ', tc.shape)
    print ('kron_pcm.shape = ', kron_pcm.shape)
    print ('new_pcm.shape = ', new_pcm.shape)
    for i in range (nL):   # version with loops
        for j in range (nL):
            pcm_mask = (tc[0] > i) & (tc[0] < i + 1) & (tc[1] > j) & (tc[1] < j + 1)
            new_pcm[:, i, j] = torch.sum (kron_pcm[:, pcm_mask], dim = 1)
    bins = torch.arange (nL, device = 'cuda')[:, None, None]   # bin lower edges, shaped for broadcasting
    if not doFull:   # loop-free version -- reduced-size problem fits in 4 GB
        print ('perform full masking and single einsum() ...')
        ma = torch.logical_and (tc[0, None] > bins, tc[0, None] < bins + 1).float()
        mb = torch.logical_and (tc[1, None] > bins, tc[1, None] < bins + 1).float()
        new_pcmB = torch.einsum ('cij, mij, nij -> cmn', kron_pcm, ma, mb)
    else:            # full-size problem doesn't fit in 4 GB -- loop over slices
        print ('perform masking and einsum() in slices, nSlice =', nSlice)
        for k in range (nSlice):
            sa = k * (nB // nSlice)
            sb = (k + 1) * (nB // nSlice)
            print ('compute masks for slice [%d:%d] ...' % (sa, sb))
            ma = torch.logical_and (tc[0, None, sa:sb, :] > bins, tc[0, None, sa:sb, :] < bins + 1).float()
            mb = torch.logical_and (tc[1, None, sa:sb, :] > bins, tc[1, None, sa:sb, :] < bins + 1).float()
            print ('perform einsum() for slice [%d:%d] ...' % (sa, sb))
            new_pcmB_slice = torch.einsum ('cij, mij, nij -> cmn', kron_pcm[:, sa:sb], ma, mb)
            if k == 0: new_pcmB  = new_pcmB_slice
            else:      new_pcmB += new_pcmB_slice
    print ('new_pcmB.shape = ', new_pcmB.shape)
    print ('torch.allclose (new_pcmB, new_pcm, atol = 1.e-4) =', torch.allclose (new_pcmB, new_pcm, atol = 1.e-4))
And here is its output:
2.1.0
11.8
4236312576
doFull: False
perform reduced-sized computation
tc.shape = torch.Size([3, 1000, 1000])
kron_pcm.shape = torch.Size([3, 1000, 1000])
new_pcm.shape = torch.Size([3, 100, 100])
perform full masking and single einsum() ...
new_pcmB.shape = torch.Size([3, 100, 100])
torch.allclose (new_pcmB, new_pcm, atol = 1.e-4) = True
doFull: True
perform full-sized computation
tc.shape = torch.Size([3, 2000, 2000])
kron_pcm.shape = torch.Size([3, 2000, 2000])
new_pcm.shape = torch.Size([3, 200, 200])
perform masking and einsum() in slices, nSlice = 8
compute masks for slice [0:250] ...
perform einsum() for slice [0:250] ...
compute masks for slice [250:500] ...
perform einsum() for slice [250:500] ...
compute masks for slice [500:750] ...
perform einsum() for slice [500:750] ...
compute masks for slice [750:1000] ...
perform einsum() for slice [750:1000] ...
compute masks for slice [1000:1250] ...
perform einsum() for slice [1000:1250] ...
compute masks for slice [1250:1500] ...
perform einsum() for slice [1250:1500] ...
compute masks for slice [1500:1750] ...
perform einsum() for slice [1500:1750] ...
compute masks for slice [1750:2000] ...
perform einsum() for slice [1750:2000] ...
new_pcmB.shape = torch.Size([3, 200, 200])
torch.allclose (new_pcmB, new_pcm, atol = 1.e-4) = True
Thanks a lot, your solution worked flawlessly! The only change I made was to use half precision (float16) instead of float32 to cut down on VRAM, since my data is in [0, 1]. On a 3090, the full-size (2000x2000) example finishes in about 80 ms and uses roughly 8 GB.
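For anyone wanting to reproduce that change, here is a sketch of what the half-precision variant might look like (my own adaptation of the reduced-size branch of the script above, not the poster's exact code):

import torch

# Reduced-size setup, same shapes as the script's small branch.
nc, nB, nL = 3, 1000, 100
tc = nL * torch.rand(nc, nB, nB, device='cuda')
kron_pcm = torch.rand(nc, nB, nB, device='cuda')   # values in [0, 1], like the real data

# Cast the 0/1 masks (and the data) to float16 to roughly halve the mask memory.
bins = torch.arange(nL, device='cuda')[:, None, None]
ma = ((tc[0, None] > bins) & (tc[0, None] < bins + 1)).half()
mb = ((tc[1, None] > bins) & (tc[1, None] < bins + 1)).half()

# einsum() runs in half precision; accuracy drops somewhat, which should be
# acceptable here because each output bin sums only a modest number of values in [0, 1].
new_pcm_half = torch.einsum('cij, mij, nij -> cmn', kron_pcm.half(), ma, mb)
print(new_pcm_half.shape, new_pcm_half.dtype)   # torch.Size([3, 100, 100]) torch.float16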