Hi! I have a very simple set of embeddings of size |x| = [b, p, dim], where along each axis of p I apply some form of scaling to the dim axis. In the original setting, matrix x is supposed to be multiplied by matrix z, where |z| = [b, p, dim]. This gives us |original_result| = [b, p, p].
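For concreteness, here is a minimal sketch of that original setting (the sizes below are placeholders, not my real ones, and I am assuming z is transposed inside the batched matmul so the result is p x p):

import torch

b, p, dim = 2, 4, 8  # placeholder sizes
x = torch.randn(b, p, dim, requires_grad=True)
z = torch.randn(b, p, dim)

original_result = torch.matmul(x, z.transpose(-1, -2))  # (b, p, p)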
Now, I wish to multiply x by a diagonal matrix y for every p, where the diagonal matrix is of size (dim, dim). This matrix y also tracks gradient operations. The diagonal entries are stored as |y| = [b, p, dim]. As such, I have been doing it this way:
y = linear(x)                # (b, p, dim): the diagonal entries
y = torch.diag_embed(y)      # (b, p, dim, dim)
x = x.unsqueeze(-2)          # (b, p, 1, dim)
result = torch.matmul(x, y)  # (b, p, 1, dim)
new_result = torch.matmul(result.squeeze(-2), z.transpose(-1, -2))  # (b, p, p)
Now, from what I have noticed, this multiplication makes the GPU memory explode. Without producing y or doing any matmul, the GPU usage is about 5 GB. However, once I introduce y and do these operations, the GPU usage hits 21 GB.
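My suspicion is that torch.diag_embed is the main culprit, since it materializes a full (b, p, dim, dim) tensor (which autograd then also keeps around for the backward pass). A rough back-of-the-envelope illustration with made-up sizes:

# Made-up sizes, not my real ones, just to show the scale of diag_embed's output
b, p, dim = 8, 256, 1024
bytes_fp32 = b * p * dim * dim * 4
print(bytes_fp32 / 1024**3)  # ~8 GiB for the diagonal blocks alone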
I am wondering how I can make this operation lighter. Are there more memory-efficient operations I can take advantage of to reduce the computing cost?
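One alternative I have been considering (I am not sure whether it is exactly equivalent, including for the tracked gradients) is to treat the multiplication by a diagonal matrix as an element-wise scaling of the dim axis, so the (dim, dim) blocks are never materialized:

# Assumes the same x, z, and linear as above (x, z: (b, p, dim))
scale = linear(x)            # (b, p, dim): the diagonal entries of y
scaled_x = x * scale         # (b, p, dim): should match x @ diag_embed(scale)?
new_result = torch.matmul(scaled_x, z.transpose(-1, -2))  # (b, p, p)

Would something like this be a valid replacement, or is there a better way (einsum, etc.) to do it?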