Multiplying with a diagonal matrix becomes very expensive

Hi! I have a very simple set of embeddings of size |x| = [b, p, dim], where, for each position along the p axis, I apply some form of scaling along the dim axis. In the original setting, matrix x is multiplied by (the transpose of) matrix z, where |z| = [b, p, dim]; this gives |original_result| = [b, p, p].

Now, for every p, I wish to multiply this by a diagonal matrix of size (dim, dim) built from a tensor y, where |y| = [b, p, dim]. This y also needs to track gradients. As such, I have been doing it this way:

y = linear(x)  # (b, p, dim)

y = torch.diag_embed(y)  # (b, p, dim, dim)

x = x.unsqueeze(-2)  # (b, p, 1, dim)

result = torch.matmul(x, y)  # (b, p, 1, dim)

new_result = torch.matmul(result.squeeze(-2), z.transpose(-1, -2))  # (b, p, p)

From what I have noticed, this multiplication makes the GPU memory explode. Without producing y and doing these matmuls, GPU usage is about 5 GB. However, once I introduce y and run the operations above, it hits 21 GB.

I am wondering how to make this operation lighter. Are there more efficient operations I can take advantage of to reduce the memory and compute cost?

Hi Legoh!

Based on the code fragment you posted, it appears that you are performing
element-wise multiplication of two tensors in a roundabout way.

If so, there is no need for the (potentially large amount of) memory required
to materialize the full [b, p, dim, dim] result of diag_embed (y).
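
For a sense of scale, here is a rough back-of-the-envelope estimate. The sizes below are made up purely for illustration (they are not your actual b, p, and dim); the point is that the materialized diag_embed (y) is dim times larger than y itself, and because it feeds a matmul, autograd will typically keep it around for the backward pass:

>>> b, p, dim = 32, 256, 512                       # hypothetical sizes, for illustration only
>>> bytes_per_elem = 4                             # float32
>>> b * p * dim * bytes_per_elem / 2**30           # GiB for y of shape (b, p, dim)
0.015625
>>> b * p * dim * dim * bytes_per_elem / 2**30     # GiB for diag_embed (y) of shape (b, p, dim, dim)
8.0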

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> b = 2
>>> p = 2
>>> dim = 3
>>>
>>> x = torch.randn (b, p, dim)     # (b, p, dim)
>>> u = x                           # save a reference to original x
>>> x = x.unsqueeze (-2)            # (b, p, 1, dim)
>>> y = torch.randn (b, p, dim)     # (b, p, dim)
>>> v = y                           # save a reference to original y
>>> y = torch.diag_embed (y)        # (b, p, dim, dim)
>>>
>>> result = torch.matmul(x, y)     # (b, p, 1, dim)
>>> result = result.squeeze (-2)    # (b, p, dim)
>>>
>>> resultB = u * v                 # element-wise multiplication of original x and y
>>>
>>> torch.equal (result, resultB)   # same result
True
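
Applied to your snippet, the same idea would look something like the sketch below. (This is just a sketch under the assumption that linear produces y from the original x and that z gets transposed to give the [b, p, p] result you describe; adjust it to your actual model.)

import torch

b, p, dim = 2, 2, 3                     # small stand-in sizes
x = torch.randn (b, p, dim)             # (b, p, dim)
z = torch.randn (b, p, dim)             # (b, p, dim)
linear = torch.nn.Linear (dim, dim)     # stand-in for your linear layer

y = linear (x)                          # (b, p, dim)
result = x * y                          # (b, p, dim) -- element-wise, no diag_embed needed
new_result = torch.matmul (result, z.transpose (-1, -2))   # (b, p, p)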

Best.

K. Frank

Thank you! I completely forgot that this was the case. This has solved the memory issue in a big way, thanks!