Hi! I have a very simple set of embeddings of size `|x| = [b, p, dim]`, where along each axis of `p` I apply some form of scaling to the `dim` axis. In the original setting, matrix `x` is supposed to be multiplied by matrix `z`, where `|z| = [b, p, dim]`. This gives us `|original_result| = [b, p, p]`.
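
For context, here is a minimal shape sketch of that original product (assuming `z` gets transposed in its last two dimensions, since that is the only way the shapes line up to `[b, p, p]`; the sizes below are just placeholders):

```
import torch

b, p, dim = 2, 4, 8                        # hypothetical sizes, only for the shape check
x = torch.randn(b, p, dim)
z = torch.randn(b, p, dim)

# Batched matmul over the leading b dimension:
# (b, p, dim) @ (b, dim, p) -> (b, p, p)
original_result = torch.matmul(x, z.transpose(-1, -2))
print(original_result.shape)               # torch.Size([2, 4, 4])
```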

Now, I wish to multiply this by a diagonal matrix `y` for every `p`, where each diagonal matrix is of size `(dim, dim)`. This matrix `y` also tracks gradient operations. The size of `y` is `|y| = [b, p, dim]`. As such, I have been doing it this way:

```
import torch

y = linear(x)                           # (b, p, dim); linear is an nn.Linear(dim, dim) defined elsewhere
y = torch.diag_embed(y)                 # (b, p, dim, dim)
x = x.unsqueeze(-2)                     # (b, p, 1, dim)
result = torch.matmul(x, y)             # (b, p, 1, dim)
new_result = torch.matmul(result.squeeze(-2), z.transpose(-1, -2))  # (b, p, p)
```

Now, from what I have noticed, this multiplication makes the GPU memory explode. Without producing `y` or doing any of the matmuls, GPU usage is about 5 GB. However, once I introduce `y` and perform the operations above, GPU usage hits 21 GB.
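
If I am reading the shapes right, the `diag_embed` step materializes a `(b, p, dim, dim)` tensor (plus whatever autograd saves for the backward pass), which is `dim` times larger than `x` itself. A quick back-of-the-envelope check with hypothetical sizes (my real `b`, `p`, `dim` differ):

```
b, p, dim = 32, 512, 1024

elems_x    = b * p * dim           # elements in x / y before diag_embed
elems_diag = b * p * dim * dim     # elements after torch.diag_embed
print(elems_diag / elems_x)        # dim, i.e. 1024x more memory for that one intermediate
```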

I am wondering how we can make this operation smoother. Are there more efficient operations I could take advantage of to reduce the compute and memory cost?
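
One rewrite I have been considering, on the grounds that multiplying a row vector by a diagonal matrix is the same as scaling it elementwise, is to drop `diag_embed` entirely. This is only a sketch under that assumption; I have not verified it behaves identically in my full model:

```
y = linear(x)                                           # (b, p, dim)
scaled = x * y                                          # elementwise; same as x_row @ diag(y) for each (b, p)
new_result = torch.matmul(scaled, z.transpose(-1, -2))  # (b, p, p), no (b, p, dim, dim) intermediate
```

Is this the right direction, or is there a better-suited operation for this kind of per-`p` diagonal scaling?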