# Multiplying with a diagonal matrix becomes very expensive

Hi! I have a very simple set of embeddings of size `|x| = [b, p, dim]`, where, for each index along `p`, I apply some form of scaling to the `dim` axis. In the original setting, matrix `x` is supposed to be multiplied by matrix `z` (transposed along its last two axes), where `|z| = [b, p, dim]`. This gives us `|original_result| = [b, p, p]`.

Now, I wish to multiply `x` by a diagonal matrix `y` for every `p`, where each diagonal matrix is of size `(dim, dim)`. This matrix `y` also tracks gradient operations. The diagonal entries are stored as `|y| = [b, p, dim]`. As such, I have been doing it this way:

``````
y = linear(x)            # (b, p, dim): diagonal entries for each p
y = torch.diag_embed(y)  # (b, p, dim, dim)

x = x.unsqueeze(-2)          # (b, p, 1, dim)
result = torch.matmul(x, y)  # (b, p, 1, dim)

new_result = torch.matmul(result.squeeze(-2), z.transpose(-1, -2))  # (b, p, p)
``````

Now, from what I have noticed, this multiplication makes GPU memory explode. Without producing `y` or doing any of these matmuls, GPU usage is about 5 GB. However, once I introduce `y` and perform the operations, it hits 21 GB.

I am wondering how to make this operation cheaper. Are there more efficient operations I can take advantage of to reduce the memory and compute cost?

Hi Legoh!

Based on the code fragment you posted, it appears that you are performing
element-wise multiplication of two tensors in a roundabout way.

If so, there is no need for the (potentially large amount of) memory required
to materialize the full `[b, p, dim, dim]` result of `diag_embed (y)`.
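A quick back-of-the-envelope calculation shows the scale of the difference
(the sizes below are purely illustrative, since your actual `b`, `p`, and
`dim` weren't stated):

``````
# Illustrative sizes only -- not taken from the original post.
b, p, dim = 8, 256, 1024
bytes_per_float = 4  # float32

# Full (b, p, dim, dim) output of diag_embed vs. just the diagonal entries.
dense_bytes = b * p * dim * dim * bytes_per_float
diag_bytes = b * p * dim * bytes_per_float

print(f"diag_embed output: {dense_bytes / 2**30:.1f} GiB")  # 8.0 GiB
print(f"diagonal entries:  {diag_bytes / 2**20:.1f} MiB")   # 8.0 MiB
``````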

Consider:

``````
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> b = 2
>>> p = 2
>>> dim = 3
>>>
>>> x = torch.randn (b, p, dim)     # (b, p, dim)
>>> u = x                           # save a reference to original x
>>> x = x.unsqueeze (-2)            # (b, p, 1, dim)
>>> y = torch.randn (b, p, dim)     # (b, p, dim)
>>> v = y                           # save a reference to original y
>>> y = torch.diag_embed (y)        # (b, p, dim, dim)
>>>
>>> result = torch.matmul(x, y)     # (b, p, 1, dim)
>>> result = result.squeeze (-2)    # (b, p, dim)
>>>
>>> resultB = u * v                 # element-wise multiplication of original x and y
>>>
>>> torch.equal (result, resultB)   # same result
True
``````
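Applied to your pipeline, the `unsqueeze` / `diag_embed` / first `matmul`
sequence collapses to a single element-wise multiply. A minimal sketch
(assuming `z` is transposed on its last two axes to produce the `[b, p, p]`
result, and using random stand-ins for your tensors):

``````
import torch

b, p, dim = 2, 4, 3
x = torch.randn(b, p, dim)
z = torch.randn(b, p, dim)
y = torch.randn(b, p, dim, requires_grad=True)  # stand-in for linear(x)

# Original approach: materialize (b, p, dim, dim) diagonal matrices.
ref = torch.matmul(x.unsqueeze(-2), torch.diag_embed(y)).squeeze(-2)
ref = torch.matmul(ref, z.transpose(-1, -2))    # (b, p, p)

# Equivalent: element-wise scale, then a single batched matmul.
out = torch.matmul(x * y, z.transpose(-1, -2))  # (b, p, p)

print(torch.allclose(ref, out))  # True
``````

Autograd still works through `x * y`, so gradients flow back to `y` exactly
as before.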

Best.

K. Frank

Thank you! I completely forgot that that was the case. This has solved the memory issue in a big way, thanks!