Optimising the memory cost of a torch einsum operation

Hello! I have two tensors of shape [200, 2000, 256] and [250, 1900, 256] (batch x seq_len x emb_size). I want to multiply every embedding in the first tensor with every embedding in the second tensor, take the max over the last dimension, and then sum over the first tensor's sequence dimension, i.e. something like torch.einsum('ijk, mnk -> ijmn', first_tensor, second_tensor).max(-1).values.sum(1). Is there any way to implement this as fast as einsum (without loops) but with less memory? Because the tensors are large, the 4D intermediate tensor is huge even without gradients and optimizer state.

Could you post the (slow) PyTorch reference code using loops here, please?

You could have a look at the torch.func.vmap function to vectorize over your batch dimension, although you have two different batch sizes of 200 in the first tensor and 250 in the second tensor and that might complicate things.
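A minimal sketch, using small random shapes instead of the real ones, of how nesting two torch.func.vmap calls can cope with two different batch sizes: the inner vmap maps only over the second tensor's batch (in_dims=(None, 0)) and the outer one only over the first tensor's batch (in_dims=(0, None)), giving one scalar per pair. The helper name pair_score is hypothetical.

```python
import torch
from torch.func import vmap

# Small stand-ins for the [200, 2000, 256] and [250, 1900, 256] tensors
a = torch.randn(4, 6, 8)   # [batch_a, seq_a, emb]
b = torch.randn(5, 7, 8)   # [batch_b, seq_b, emb]

def pair_score(x, y):
    # x: [seq_a, emb], y: [seq_b, emb] -> scalar (max over seq_b, sum over seq_a)
    return (x @ y.T).max(dim=-1).values.sum()

# Inner vmap: keep x fixed, map over the batch of y; outer vmap: map over the batch of x
scores = vmap(vmap(pair_score, in_dims=(None, 0)), in_dims=(0, None))(a, b)
print(scores.shape)  # torch.Size([4, 5])
```

The different batch sizes are not a problem because each vmap only ever sees one of them.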

As @ptrblck said, please share a minimal reproducible example of your problem.

Thank you for your replies. Here is an example:

import torch

tensor_a = torch.randn(300, 2000, 256)
tensor_b = torch.randn(315, 1990, 256)

# einsum version: materialises a [300, 2000, 315, 1990] intermediate tensor
x = torch.einsum('ijk, mnk -> ijmn', tensor_a, tensor_b).max(-1).values.sum(1)

# Loop version: only one [2000, 1990] block exists at a time, but it is slow
result = torch.zeros(300, 315)
for i in range(tensor_a.shape[0]):
    for j in range(tensor_b.shape[0]):
        result[i, j] = (tensor_a[i] @ tensor_b[j].T).max(-1).values.sum(-1)

print(result)
print(x)

The problem with einsum is that it first creates a 4D tensor of shape 300x2000x315x1990, which takes a lot of memory in fp32 even without gradients and optimizer states.
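To put a number on "a lot of memory", here is the arithmetic for the shapes in the example above, assuming float32 (4 bytes per element) and counting only the einsum output itself.

```python
import math

# Shape of the 'ijmn' einsum output for tensor_a [300, 2000, 256] and tensor_b [315, 1990, 256]
shape = (300, 2000, 315, 1990)
bytes_per_elem = 4  # float32

numel = math.prod(shape)
total_bytes = numel * bytes_per_elem
# Even a single batch entry of tensor_a against all of tensor_b is a [2000, 315, 1990] block
per_row_bytes = total_bytes // shape[0]

print(f"full 4D tensor: {numel:.3e} elements, {total_bytes / 1e12:.2f} TB")  # 3.761e+11 elements, 1.50 TB
print(f"per tensor_a batch entry: {per_row_bytes / 1e9:.2f} GB")            # 5.01 GB
```

So at these sizes any memory-bounded approach probably has to split along both batch dimensions, not just one.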

So, I’ve had a look at this problem and I’m not 100% sure there’s a way to mitigate the memory issue. I don’t think there’s an associativity trick that can be exploited here (at least none that I’m aware of).

One thing that might help performance (at least in terms of walltime) is to vectorize the operation with torch.func.vmap and 'chunk' the amount of data that goes through vmap at once, trading off walltime against memory.

A minimal reproducible example is below (with smaller input tensors, tensor_a and tensor_b). You can change the mini-batch size (chunk_size) via chunk_size_a and chunk_size_b inside vmap_fn.

import torch
from torch import Tensor

# Scaled-down sizes (the real problem uses 200, 250, 2000, 1900 and 256)
batch_a = 20
batch_b = 25

seq_len_a = 20
seq_len_b = 19

emb_size = 25

tensor_a = torch.randn(batch_a, seq_len_a, emb_size)
tensor_b = torch.randn(batch_b, seq_len_b, emb_size)

def original_fn(tensor_a: Tensor, tensor_b: Tensor) -> Tensor:
    result = torch.zeros(batch_a, batch_b)
    # Full 4D einsum, kept so that it shows up in the profile; x itself is not returned
    x = torch.einsum('ijk, mnk -> ijmn', tensor_a, tensor_b).max(-1).values.sum(1)

    # Loop over every pair of batch entries
    for i in range(tensor_a.shape[0]):
        for j in range(tensor_b.shape[0]):
            result[i, j] = (tensor_a[i] @ tensor_b[j].T).max(-1).values.sum(-1)

    return result

from torch.func import vmap

def vmap_fn(tensor_a: Tensor, tensor_b: Tensor) -> Tensor:
    chunk_size_a = None  # None means the whole batch is vmapped at once
    chunk_size_b = None

    def _max_fn(a, b):
        # a: [seq_len_a, emb_size], b: [seq_len_b, emb_size] -> scalar
        return torch.max(a @ b.T, dim=-1).values.sum(dim=-1)

    # Inner vmap maps over the batch of tensor_b, outer vmap over the batch of tensor_a
    result = vmap(
        vmap(_max_fn, in_dims=(None, 0), chunk_size=chunk_size_b),
        in_dims=(0, None), chunk_size=chunk_size_a,
    )(tensor_a, tensor_b)
    return result
	
result_og = original_fn(tensor_a, tensor_b)
result_vmap = vmap_fn(tensor_a, tensor_b)

print(torch.allclose(result_og, result_vmap)) #returns True

import timeit

t0 = timeit.Timer(
    stmt='original_fn(tensor_a, tensor_b)',
    setup='from __main__ import original_fn',
    globals={'tensor_a':tensor_a, 'tensor_b':tensor_b})

t1 = timeit.Timer(
    stmt='vmap_fn(tensor_a, tensor_b)',
    setup='from __main__ import vmap_fn',
    globals={'tensor_a':tensor_a, 'tensor_b':tensor_b})

original_time = t0.timeit(100) / 100 * 1e6
vmap_time = t1.timeit(100) / 100 * 1e6
print(f'original_fn(a, b):  {original_time:>5.1f} us')
print(f'vmap_fn(a, b):      {vmap_time:>5.1f} us')
print(f'vmap_fn(a, b)/original_fn(a, b) ratio: {original_time/vmap_time:.3f} times faster') 

from torch.profiler import profile, record_function, ProfilerActivity
#cuda_time_total, cpu_time_total
#cuda_memory_usage, cpu_memory_usage

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    with record_function("original_fn"):
        original_fn(tensor_a, tensor_b)

print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    with record_function("vmap_fn"):
        vmap_fn(tensor_a, tensor_b)

print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

This returns:

True
original_fn(a, b):  17656.7 us
vmap_fn(a, b):      536.8 us
vmap_fn(a, b)/original_fn(a, b) ratio: 32.892 times faster
STAGE:2024-05-19 15:39:42 13888:13888 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-19 15:39:43 13888:13888 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-19 15:39:43 13888:13888 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::einsum         0.03%      48.000us         0.12%     214.000us     214.000us     742.19 Kb           0 b             1  
                  aten::bmm         0.07%     115.000us         0.07%     115.000us     115.000us     742.19 Kb     742.19 Kb             1  
               aten::matmul         0.64%       1.120ms         2.78%       4.866ms       9.732us     742.19 Kb      29.69 Kb           500  
                   aten::mm         2.23%       3.898ms         2.23%       3.898ms       7.796us     742.19 Kb     742.19 Kb           500  
                  aten::max         3.72%       6.502ms         5.80%      10.150ms      20.259us     234.38 Kb     234.38 Kb           501  
                  aten::sum         1.88%       3.295ms         2.09%       3.664ms       7.313us       3.91 Kb       3.91 Kb           501  
                aten::zeros         0.01%      14.000us         0.01%      23.000us      23.000us       1.95 Kb           0 b             1  
                aten::empty         0.01%       9.000us         0.01%       9.000us       9.000us       1.95 Kb       1.95 Kb             1  
                original_fn         8.59%      15.028ms        24.19%      42.321ms      42.321ms           0 b      -1.71 Mb             1  
                aten::zero_         0.00%       1.000us         0.00%       1.000us       1.000us           0 b           0 b             1  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 174.937ms

STAGE:2024-05-19 15:39:44 13888:13888 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-19 15:39:44 13888:13888 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-19 15:39:44 13888:13888 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::matmul         3.43%      39.000us       166.55%       1.892ms     630.667us       2.17 Mb      -1.86 Mb             3  
            aten::reshape         2.82%      32.000us        37.59%     427.000us      85.400us       1.86 Mb           0 b             5  
              aten::clone         1.32%      15.000us        33.80%     384.000us     192.000us       1.86 Mb           0 b             2  
         aten::empty_like         0.62%       7.000us         3.52%      40.000us      20.000us       1.86 Mb           0 b             2  
              aten::empty         2.90%      33.000us         2.90%      33.000us      16.500us       1.86 Mb       1.86 Mb             2  
                aten::bmm        13.20%     150.000us        64.00%     727.000us     363.500us       1.45 Mb     742.19 Kb             2  
              original_fn        19.81%     225.000us        98.86%       1.123ms       1.123ms     781.25 Kb     -80.08 Kb             1  
                 aten::mm         0.18%       2.000us        57.75%     656.000us     656.000us     742.19 Kb           0 b             1  
                aten::max        13.38%     152.000us        14.52%     165.000us     165.000us     117.19 Kb     117.19 Kb             1  
                aten::sum         3.26%      37.000us         3.52%      40.000us      40.000us       1.95 Kb       1.95 Kb             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.136ms

EDIT: I’ve just realized that the loop-based approach in my example never fully materializes the x tensor, so the method should actually mitigate the memory cost of your einsum operation: the example code simply loops over the individual batch entries of tensor_a and tensor_b. The only additional change I made is to vectorize those loops with torch.func.vmap and then chunk the vmap to stay within memory constraints. Note that with chunk_size=None the vmapped matmul is batched over all pairs at once, so the full [batch_a, batch_b, seq_len_a, seq_len_b] intermediate is still created (the vmap profile above shows larger allocations than einsum); the memory saving comes from setting chunk_size. If you prefer to keep the for-loop based approach, you can also loop without ever materializing the full x tensor.
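A minimal sketch of the for-loop variant, assuming it is acceptable to process tensor_a in chunks of chunk_a batch entries and tensor_b in chunks of chunk_b entries (both hypothetical parameter names, as is chunked_max_sum). Only a [chunk_a, seq_a, chunk_b, seq_b] block of the 4D product exists at any time, so peak memory scales with the chunk sizes rather than the full batch sizes, and each block still uses a single vectorized einsum.

```python
import torch
from torch import Tensor

def chunked_max_sum(tensor_a: Tensor, tensor_b: Tensor,
                    chunk_a: int = 8, chunk_b: int = 8) -> Tensor:
    """Computes einsum('ijk,mnk->ijmn').max(-1).values.sum(1) block by block."""
    rows = []
    for a_blk in tensor_a.split(chunk_a, dim=0):              # [ca, seq_a, emb]
        cols = []
        for b_blk in tensor_b.split(chunk_b, dim=0):          # [cb, seq_b, emb]
            block = torch.einsum('ijk,mnk->ijmn', a_blk, b_blk)  # [ca, seq_a, cb, seq_b]
            cols.append(block.max(dim=-1).values.sum(dim=1))     # [ca, cb]
        rows.append(torch.cat(cols, dim=1))                   # [ca, batch_b]
    return torch.cat(rows, dim=0)                             # [batch_a, batch_b]

# Check against the full einsum on small inputs
a = torch.randn(20, 30, 16)
b = torch.randn(25, 19, 16)
full = torch.einsum('ijk,mnk->ijmn', a, b).max(-1).values.sum(1)
print(torch.allclose(chunked_max_sum(a, b), full, rtol=1e-4, atol=1e-4))
```

Larger chunks mean fewer Python iterations and better kernel utilisation; smaller chunks mean lower peak memory.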

@ptrblck Is it possible to use gradient checkpointing here?

@AlphaBetaGamma96 I am not sure that it will work correctly with gradients on the backward pass.

@mY_NAme It should. Run both the original version and my version and check that you get the same result; I don’t see why using torch.func.vmap should affect the gradients on the backward pass.
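One way to run that comparison, sketched on small random inputs: compute the result with the full einsum and with the nested vmap (here with chunk_size=2 on both levels), backpropagate a stand-in loss (result.sum(), an assumption; substitute your real loss), and compare values and input gradients. Ordinary autograd calls such as torch.autograd.grad work through torch.func.vmap.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
a = torch.randn(6, 10, 8, requires_grad=True)
b = torch.randn(5, 9, 8, requires_grad=True)

def _max_fn(x, y):
    # x: [seq_a, emb], y: [seq_b, emb] -> scalar
    return torch.max(x @ y.T, dim=-1).values.sum(dim=-1)

# Reference: full 4D einsum
ref = torch.einsum('ijk,mnk->ijmn', a, b).max(-1).values.sum(1)
ga_ref, gb_ref = torch.autograd.grad(ref.sum(), (a, b))

# Chunked nested vmap: at most 2 entries per level are processed at once
out = vmap(vmap(_max_fn, in_dims=(None, 0), chunk_size=2),
           in_dims=(0, None), chunk_size=2)(a, b)
ga, gb = torch.autograd.grad(out.sum(), (a, b))

checks = (torch.allclose(out, ref, rtol=1e-4, atol=1e-5),
          torch.allclose(ga, ga_ref, rtol=1e-4, atol=1e-5),
          torch.allclose(gb, gb_ref, rtol=1e-4, atol=1e-5))
print(checks)
```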

@AlphaBetaGamma96 Because you never fully materialize the x tensor, you never have the full 4D tensor needed to run the backward pass correctly. You need the full 4D tensor on backward anyway.

So what exactly is the derivative you’re computing: the derivative of the results with respect to what?

@AlphaBetaGamma96 It should be the derivative of the results with respect to the inputs. But since the inputs are turned into a 4D tensor and then reduced with max and sum, the result depends on the sums of maximums, and those depend on the 4D tensor, so it looks like I need this 4D tensor on the backward pass.
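The chain rule for this reduction can be written out explicitly. Here A and B are the two input tensors, S the 4D einsum product, R the [batch_a, batch_b] result, L the loss and G = ∂L/∂R the upstream gradient. Because max only passes gradient through its argmax, the derivative depends on the inputs and the 3D argmax index tensor n*(i, j, m) (300 x 2000 x 315 entries for the shapes above), not on the values of the 4D product. PyTorch's default backward for max still scatters the gradient into a buffer the size of its input, though, which is why chunking the computation bounds the backward memory as well.

```latex
\begin{aligned}
S_{ijmn} &= \sum_{k} A_{ijk} B_{mnk}, \qquad
R_{im} = \sum_{j} \max_{n} S_{ijmn}, \qquad
n^{*}(i,j,m) = \arg\max_{n} S_{ijmn} \\
\frac{\partial L}{\partial A_{ijk}} &= \sum_{m} G_{im}\, B_{m,\,n^{*}(i,j,m),\,k} \\
\frac{\partial L}{\partial B_{mnk}} &= \sum_{i} G_{im} \sum_{j} \mathbb{1}\!\left[n = n^{*}(i,j,m)\right] A_{ijk}
\end{aligned}
```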

In that case, you could take my vectorized approach and pass torch.func.grad(_max_fn) (with respect to the desired inputs) instead of _max_fn, which will give you the gradients. However, you’ll have to modify how you collect the gradient terms.
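A minimal sketch of that suggestion, assuming the loss is simply the sum of all entries of the result, so every pair (i, j) receives an upstream gradient of 1 (a different loss would need those per-pair weights, for example via torch.func.vjp). torch.func.grad(_max_fn, argnums=(0, 1)) returns per-pair gradients, and "collecting" them means summing over the pair dimension each input was broadcast across. Note that the stacked per-pair gradient for tensor_a has shape [batch_a, batch_b, seq_a, emb], which can itself be large at full size.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
a = torch.randn(6, 10, 8)
b = torch.randn(5, 9, 8)

def _max_fn(x, y):
    return torch.max(x @ y.T, dim=-1).values.sum(dim=-1)

# Per-pair gradients of _max_fn with respect to both arguments
pair_grad = grad(_max_fn, argnums=(0, 1))
ga_pairs, gb_pairs = vmap(vmap(pair_grad, in_dims=(None, 0)), in_dims=(0, None))(a, b)
# ga_pairs: [6, 5, 10, 8], gb_pairs: [6, 5, 9, 8]

# Collect: a[i] takes part in every pair (i, j), so sum over j; likewise b[j] over i
grad_a = ga_pairs.sum(dim=1)   # [6, 10, 8]
grad_b = gb_pairs.sum(dim=0)   # [5, 9, 8]

# Reference via ordinary autograd through the full einsum
a_ref = a.clone().requires_grad_()
b_ref = b.clone().requires_grad_()
torch.einsum('ijk,mnk->ijmn', a_ref, b_ref).max(-1).values.sum(1).sum().backward()
print(torch.allclose(grad_a, a_ref.grad, atol=1e-5),
      torch.allclose(grad_b, b_ref.grad, atol=1e-5))
```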

@AlphaBetaGamma96 Sounds easy to make a mistake with this approach :)

Well, give it a go and see if it works by comparing it to the original version. If it doesn’t work, then I can look for a different approach.

I'm unsure whether gradient checkpointing would help here, but in a similar way you could try offloading the intermediate activations to the host, in case these are causing the large memory usage.
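A minimal sketch of offloading with PyTorch's torch.autograd.graph.save_on_cpu context manager: tensors that autograd saves for the backward pass inside the block are kept in host memory and copied back to the device when backward needs them. It only helps if saved activations, rather than the transient 4D product itself, dominate memory; the chunk size of 4 and the small shapes are illustrative assumptions, and pin_memory is enabled only when CUDA is available.

```python
import torch
from torch.autograd.graph import save_on_cpu

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(20, 30, 16, device=device, requires_grad=True)
b = torch.randn(25, 19, 16, device=device, requires_grad=True)

# Tensors saved for backward inside this block are stored on the CPU
with save_on_cpu(pin_memory=torch.cuda.is_available()):
    result = torch.cat(
        [torch.einsum('ijk,mnk->ijmn', a_blk, b).max(-1).values.sum(1)
         for a_blk in a.split(4, dim=0)],   # process tensor_a 4 entries at a time
        dim=0,
    )

result.sum().backward()  # saved tensors are copied back to `device` here
print(result.shape, a.grad.shape, b.grad.shape)
```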

You can also try applying torch.compile to the nested-loop approach to check whether it can get rid of the loops.

@ptrblck But why can't gradient checkpointing help here? In the default setup, I have these 3D tensors as inputs, get a 4D tensor from einsum, reduce it to 2D with the sum of maximums, and compute the loss. What if I checkpoint everything before the einsum and then compute the einsum for the first chunk of k entries of the first tensor, e.g. torch.einsum('ijk, mnk -> ijmn', first_tensor[:k], second_tensor) with k ~ first_tensor.shape[0] / 10, compute the sum of maximums of this small tensor and the loss, remove this small 4D tensor from GPU memory, and then repeat the same steps for the second chunk of the first tensor, the third, and so on? The key point is that every 4D tensor is removed from GPU memory. On the backward pass, the forward is recomputed from the checkpoint, so the 4D tensor is computed again, but it is much smaller because of the chunking; for every chunk we recompute this small 4D tensor to get the gradients and then remove it from GPU memory again. In my mind this works well. Of course, it is slower because of checkpointing, but it should use much less memory than the default einsum on the full tensors, shouldn't it?
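The scheme described above can be sketched with torch.utils.checkpoint.checkpoint in non-reentrant mode: each chunk of the first tensor runs inside a checkpointed function, so its intermediates are not kept after the forward pass and are recomputed chunk by chunk during backward. The names chunk_scores and k are hypothetical, and result.sum() stands in for the real loss. Whether the checkpoint wrapper saves memory compared with plain chunking depends on what autograd would otherwise keep (for max it keeps only the argmax indices, not its 4D input), so profiling both variants is worthwhile.

```python
import torch
from torch.utils.checkpoint import checkpoint

def chunk_scores(a_blk, b):
    # [k, seq_a, emb] x [batch_b, seq_b, emb] -> [k, batch_b]; the 4D block is transient
    return torch.einsum('ijk,mnk->ijmn', a_blk, b).max(-1).values.sum(1)

torch.manual_seed(0)
a = torch.randn(20, 30, 16, requires_grad=True)
b = torch.randn(25, 19, 16, requires_grad=True)
k = 4  # chunk size along the batch dimension of the first tensor

result = torch.cat(
    [checkpoint(chunk_scores, a_blk, b, use_reentrant=False) for a_blk in a.split(k, dim=0)],
    dim=0,
)  # [20, 25]
result.sum().backward()  # each chunk's forward is recomputed here

# Compare with the un-chunked einsum
a2 = a.detach().clone().requires_grad_()
b2 = b.detach().clone().requires_grad_()
torch.einsum('ijk,mnk->ijmn', a2, b2).max(-1).values.sum(1).sum().backward()
print(torch.allclose(a.grad, a2.grad, atol=1e-5), torch.allclose(b.grad, b2.grad, atol=1e-5))
```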