Hello! If I have two tensors of shape [200, 2000, 256] and [250, 1900, 256] (like batch x seq_len x emb_size), and I want to multiply every embedding in the first tensor with every embedding in the second tensor, then take the max over the last dimension and then the sum over the last dimension, i.e. something like torch.einsum('ijk, mnk -> ijmn', first_tensor, second_tensor).max(-1).values.sum(-1), is there any method to implement this as fast as einsum (without loops) but with less memory than einsum? Because of the large tensor dimensions, the 4d intermediate is huge even without gradients and optimizers.
Could you post the (slow) PyTorch reference code using loops here, please?
You could have a look at the torch.func.vmap function to vectorize over your batch dimension, although you have two different batch sizes (200 in the first tensor and 250 in the second), which might complicate things.
As @ptrblck stated, share a minimal reproducible example of your problem.
Thank you for your reply. For example, this one:
import torch

tensor_a = torch.randn(300, 2000, 256)
tensor_b = torch.randn(315, 1990, 256)

result = torch.zeros(300, 315)

# einsum reference: materializes the full [300, 2000, 315, 1990] intermediate
x = torch.einsum('ijk, mnk -> ijmn', tensor_a, tensor_b).max(-1).values.sum(1)

# loop reference: only one [2000, 1990] matrix in memory at a time
for i in range(tensor_a.shape[0]):
    for j in range(tensor_b.shape[0]):
        result[i, j] = (tensor_a[i] @ tensor_b[j].T).max(-1).values.sum(-1)

print(result)
print(x)
The problem with einsum is that it first creates a 4d tensor of shape 300x2000x315x1990, and it takes a lot of memory in fp32 even without gradients and optimisers.
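For scale, a quick back-of-the-envelope estimate of that intermediate (my own illustrative calculation, ignoring any extra workspace PyTorch may allocate):

# size of the [300, 2000, 315, 1990] fp32 intermediate from the snippet above
elems = 300 * 2000 * 315 * 1990           # ~3.76e11 elements
print(f"{elems * 4 / 1e12:.2f} TB")       # ~1.50 TB, before gradients or optimizer state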
So, I've had a look at this problem and I'm not 100% sure there is a way to mitigate the memory issue. I don't think there's an associativity trick that can be exploited here (at least not one that I'm aware of).
One thing that might help performance (at least in terms of walltime) is to vectorize the operation and just 'chunk' the amount of data that goes through torch.func.vmap, which mitigates the memory cost and lets you trade off walltime against memory constraints.
A minimal reproducible example can be found below (for smaller input tensors tensor_a and tensor_b). You can change the mini-batch size (or chunk_size) with the chunk_size_a and chunk_size_b variables inside vmap_fn.
import torch
from torch import Tensor
from torch.func import vmap

# scaled-down sizes; delete each '#' to restore the full 200/250/2000/1900/256 shapes
batch_a = 20#0
batch_b = 25#0
seq_len_a = 20#00
seq_len_b = 19#00
emb_size = 25#6

tensor_a = torch.randn(batch_a, seq_len_a, emb_size)
tensor_b = torch.randn(batch_b, seq_len_b, emb_size)

def original_fn(tensor_a: Tensor, tensor_b: Tensor) -> Tensor:
    result = torch.zeros(batch_a, batch_b)
    # einsum reference (materializes the full 4d intermediate)
    x = torch.einsum('ijk, mnk -> ijmn', tensor_a, tensor_b).max(-1).values.sum(1)
    # loop reference
    for i in range(tensor_a.shape[0]):
        for j in range(tensor_b.shape[0]):
            result[i, j] = (tensor_a[i] @ tensor_b[j].T).max(-1).values.sum(-1)
    return result

def vmap_fn(tensor_a: Tensor, tensor_b: Tensor) -> Tensor:
    chunk_size_a = None  # None is equal to a full-batch vmap
    chunk_size_b = None

    def _max_fn(a, b):
        return torch.max(a @ b.T, dim=-1).values.sum(dim=-1)

    result = vmap(
        vmap(
            _max_fn, in_dims=(None, 0), chunk_size=chunk_size_b  # maps/chunks over tensor_b
        ), in_dims=(0, None), chunk_size=chunk_size_a             # maps/chunks over tensor_a
    )(tensor_a, tensor_b)
    return result
result_og = original_fn(tensor_a, tensor_b)
result_vmap = vmap_fn(tensor_a, tensor_b)
print(torch.allclose(result_og, result_vmap)) #returns True
import timeit

t0 = timeit.Timer(
    stmt='original_fn(tensor_a, tensor_b)',
    setup='from __main__ import original_fn',
    globals={'tensor_a': tensor_a, 'tensor_b': tensor_b})
t1 = timeit.Timer(
    stmt='vmap_fn(tensor_a, tensor_b)',
    setup='from __main__ import vmap_fn',
    globals={'tensor_a': tensor_a, 'tensor_b': tensor_b})

original_time = t0.timeit(100) / 100 * 1e6  # average per-call time in microseconds
vmap_time = t1.timeit(100) / 100 * 1e6

print(f'original_fn(a, b): {original_time:>5.1f} us')
print(f'vmap_fn(a, b): {vmap_time:>5.1f} us')
print(f'vmap_fn(a, b)/original_fn(a, b) ratio: {original_time/vmap_time:.3f} times faster')
from torch.profiler import profile, record_function, ProfilerActivity

# other useful sort keys: cuda_time_total, cpu_time_total, cuda_memory_usage, cpu_memory_usage
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    with record_function("original_fn"):
        original_fn(tensor_a, tensor_b)
print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    with record_function("vmap_fn"):
        vmap_fn(tensor_a, tensor_b)
print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))
This returns:
True
original_fn(a, b): 17656.7 us
vmap_fn(a, b): 536.8 us
vmap_fn(a, b)/original_fn(a, b) ratio: 32.892 times faster
STAGE:2024-05-19 15:39:42 13888:13888 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-19 15:39:43 13888:13888 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-19 15:39:43 13888:13888 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::einsum 0.03% 48.000us 0.12% 214.000us 214.000us 742.19 Kb 0 b 1
aten::bmm 0.07% 115.000us 0.07% 115.000us 115.000us 742.19 Kb 742.19 Kb 1
aten::matmul 0.64% 1.120ms 2.78% 4.866ms 9.732us 742.19 Kb 29.69 Kb 500
aten::mm 2.23% 3.898ms 2.23% 3.898ms 7.796us 742.19 Kb 742.19 Kb 500
aten::max 3.72% 6.502ms 5.80% 10.150ms 20.259us 234.38 Kb 234.38 Kb 501
aten::sum 1.88% 3.295ms 2.09% 3.664ms 7.313us 3.91 Kb 3.91 Kb 501
aten::zeros 0.01% 14.000us 0.01% 23.000us 23.000us 1.95 Kb 0 b 1
aten::empty 0.01% 9.000us 0.01% 9.000us 9.000us 1.95 Kb 1.95 Kb 1
original_fn 8.59% 15.028ms 24.19% 42.321ms 42.321ms 0 b -1.71 Mb 1
aten::zero_ 0.00% 1.000us 0.00% 1.000us 1.000us 0 b 0 b 1
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 174.937ms
STAGE:2024-05-19 15:39:44 13888:13888 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-05-19 15:39:44 13888:13888 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-19 15:39:44 13888:13888 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::matmul 3.43% 39.000us 166.55% 1.892ms 630.667us 2.17 Mb -1.86 Mb 3
aten::reshape 2.82% 32.000us 37.59% 427.000us 85.400us 1.86 Mb 0 b 5
aten::clone 1.32% 15.000us 33.80% 384.000us 192.000us 1.86 Mb 0 b 2
aten::empty_like 0.62% 7.000us 3.52% 40.000us 20.000us 1.86 Mb 0 b 2
aten::empty 2.90% 33.000us 2.90% 33.000us 16.500us 1.86 Mb 1.86 Mb 2
aten::bmm 13.20% 150.000us 64.00% 727.000us 363.500us 1.45 Mb 742.19 Kb 2
original_fn 19.81% 225.000us 98.86% 1.123ms 1.123ms 781.25 Kb -80.08 Kb 1
aten::mm 0.18% 2.000us 57.75% 656.000us 656.000us 742.19 Kb 0 b 1
aten::max 13.38% 152.000us 14.52% 165.000us 165.000us 117.19 Kb 117.19 Kb 1
aten::sum 3.26% 37.000us 3.52% 40.000us 40.000us 1.95 Kb 1.95 Kb 1
------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 1.136ms
EDIT: So, I've just realized that within my example I never fully materialize the x matrix, so the method should actually mitigate the memory cost of your einsum operation. The example code simply performs a for-loop over the individual tensors tensor_a and tensor_b; the only change I made is to vectorize the for-loops with torch.func.vmap and then chunk the vmap operation to abide by memory constraints. You should also be able to loop without fully materializing the x tensor in your example, if you want to keep the for-loop based approach.
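For reference, here is a minimal sketch of that chunked-loop idea (the chunked_fn name and chunk parameter are mine and purely illustrative): it processes blocks of tensor_a rows, so only a [chunk, seq_len_a, batch_b, seq_len_b] intermediate ever exists rather than the full 4d tensor.

import torch

def chunked_fn(tensor_a, tensor_b, chunk=10):
    out = []
    for start in range(0, tensor_a.shape[0], chunk):
        a_chunk = tensor_a[start:start + chunk]                    # [chunk, seq_a, emb]
        # per-chunk 4d intermediate [chunk, seq_a, batch_b, seq_b], freed every iteration
        sim = torch.einsum('ijk, mnk -> ijmn', a_chunk, tensor_b)
        out.append(sim.max(-1).values.sum(1))                      # [chunk, batch_b]
    return torch.cat(out, dim=0)                                   # [batch_a, batch_b]

Smaller chunk values shrink the peak intermediate at the cost of more iterations, which is the same trade-off the chunk_size argument of vmap gives you.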
@ptrblck Is it possible to use gradient checkpointing here?
@AlphaBetaGamma96 I am not sure that it will work correctly with gradients on the backward pass
@mY_NAme It should do; run both the original version and my version and see if you get the same result. I don't see why using torch.func.vmap should affect the gradients on the backward pass.
@AlphaBetaGamma96 Because you never fully materialize the x tensor, you never have the full 4d tensor needed to run backward correctly. You need the full 4d tensor on the backward pass anyway
So what exactly is the derivative you're computing: results with respect to what?
@AlphaBetaGamma96 It should be the results with respect to the inputs. But the inputs are converted into the 4d tensor, and then I take maximums and sums, so the result depends on the sums of maximums, and the sums of maximums depend on the 4d tensor. So it looks like I need this 4d tensor on the backward pass
In that case, you could take my vectorized approach and pass torch.func.grad(_max_fn) (with respect to the desired inputs) instead of _max_fn, and that will give you your gradients. Although you'll have to modify how you collect your gradient terms.
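To make that concrete, here is a hedged sketch of the wiring (my own, reusing _max_fn from the example above): grad(_max_fn) is pushed through the same nested vmap, and the per-pair gradients are then reduced by hand, which here corresponds to a loss of result.sum(). Note that the per-pair gradient tensors are themselves large, so in practice you would set the chunk sizes and/or reduce chunk by chunk.

import torch
from torch import Tensor
from torch.func import vmap, grad

def _max_fn(a: Tensor, b: Tensor) -> Tensor:
    # a: [seq_len_a, emb], b: [seq_len_b, emb] -> scalar, as in the example above
    return torch.max(a @ b.T, dim=-1).values.sum(dim=-1)

def vmap_grad_fn(tensor_a: Tensor, tensor_b: Tensor):
    chunk_size_a = None  # set to small ints to trade walltime for memory
    chunk_size_b = None
    # per-pair gradients: grad_a is [A, B, seq_len_a, emb], grad_b is [A, B, seq_len_b, emb]
    grad_a, grad_b = vmap(
        vmap(
            grad(_max_fn, argnums=(0, 1)), in_dims=(None, 0), chunk_size=chunk_size_b
        ), in_dims=(0, None), chunk_size=chunk_size_a
    )(tensor_a, tensor_b)
    # summing over the other batch dimension corresponds to loss = result.sum()
    return grad_a.sum(1), grad_b.sum(0)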
@AlphaBetaGamma96 It sounds hard not to make a mistake with this approach :)
Well, give it a go and see if it works by comparing it to the original version. If it doesn’t work, then I can look for a different approach.
I'm unsure if gradient checkpointing would help here, but in a similar way you could try to offload the intermediate activations to the host, in case these are causing the large memory usage.
You can also generally try to apply torch.compile to the nested loop approach to check if it could get rid of the loops.
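For reference, a hedged sketch of both suggestions (whether the offloading helps depends on whether tensors saved for backward, rather than forward temporaries, dominate the memory; the names reuse the earlier example):

import torch

tensor_a.requires_grad_(True)
tensor_b.requires_grad_(True)

# offload tensors saved for the backward pass to pinned host memory
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    result = vmap_fn(tensor_a, tensor_b)   # reuses vmap_fn from the example above
result.sum().backward()

# separately: let torch.compile try to optimize the nested-loop version
compiled_fn = torch.compile(original_fn)
result_compiled = compiled_fn(tensor_a, tensor_b)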
@ptrblck But why can't gradient checkpointing help here? Let's say that in the default setup I have these 3d tensors as inputs, I get a 4d tensor via einsum, reduce it to 2d with sums of maximums, and calculate the loss. But what if I checkpoint everything before the einsum, then calculate the einsum for only the first chunk of the first tensor, like torch.einsum('ijk, mnk -> ijmn', first_tensor[:k], second_tensor), where k ~ first_tensor.shape[0] / 10, then calculate the sums of maximums of this small tensor, calculate the loss, free this small 4d tensor from GPU memory, and repeat the same steps with the second chunk of the first tensor, the third chunk, and so on? The main thing is that we free all the 4d tensors from GPU memory. On the backward pass we rerun the forward from the checkpoint and calculate the 4d tensor again, but it is a much smaller tensor because of chunking; for every chunk we recompute this small 4d tensor to calculate the gradients and then free it from GPU memory again. In my mind it works well. Of course it is slower because of checkpointing, but it should use much less memory than the default einsum with big tensors, shouldn't it?
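A minimal sketch of that chunked-recomputation idea, assuming the loss is simply the sum of the reduced scores (the chunk_score / chunked_loss names and the chunk size k are illustrative, and torch.utils.checkpoint is used as the checkpointing mechanism):

import torch
from torch.utils.checkpoint import checkpoint

def chunk_score(a_chunk, b):
    # a_chunk: [k, seq_a, emb], b: [B, seq_b, emb]
    # only a [k, seq_a, B, seq_b] intermediate exists; checkpointing means nothing from
    # this region is kept for backward, it is recomputed chunk by chunk instead
    sim = torch.einsum('ijk, mnk -> ijmn', a_chunk, b)
    return sim.max(-1).values.sum(1)                     # [k, B]

def chunked_loss(tensor_a, tensor_b, k):
    loss = tensor_a.new_zeros(())
    for start in range(0, tensor_a.shape[0], k):
        a_chunk = tensor_a[start:start + k]
        scores = checkpoint(chunk_score, a_chunk, tensor_b, use_reentrant=False)
        loss = loss + scores.sum()
    return loss

# usage sketch (sizes as in the original question):
# tensor_a = torch.randn(200, 2000, 256, device='cuda', requires_grad=True)
# tensor_b = torch.randn(250, 1900, 256, device='cuda', requires_grad=True)
# loss = chunked_loss(tensor_a, tensor_b, k=20)
# loss.backward()

Peak memory is then set by the per-chunk [k, seq_len_a, batch_b, seq_len_b] tensor during the forward and its recomputation during the backward, at the cost of running each chunk's forward twice.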