Hi all,

I have an equation like this:

**v**, **h**, and **M** are learnable parameters, and the size of **v** is `n*k*1`. I implemented it as follows (`b` is the batch size):

```
(v * x).transpose(-1, -2).sum(dim=1).matmul(( # b*1*k
(h * v).matmul( # b*n*k*1
v.transpose(-1, -2).matmul( # b*n*1*1
M.matmul(embed_deep))) # b*n*k*1
* x).sum(dim=1) # b*1*1
```

However, I found that this snippet consumes a large amount of GPU memory (e.g., 20 GB when I set `k=512`). In my view, this is just a few matrix multiplications and summations, so how could it possibly use so much GPU memory?
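For context on where the memory might go (a guess, since the full shapes of `x`, `h`, and `embed_deep` aren't shown): each broadcasted elementwise product such as `v * x` or `h * v` materializes a full `b*n*k*1` tensor, and autograd keeps every such intermediate alive for the backward pass, so the cost is several multiples of a single tensor. A minimal back-of-the-envelope sketch, assuming hypothetical sizes `b=1024`, `n=32`, `k=512` and float32:

```
def tensor_bytes(shape, bytes_per_el=4):
    """Bytes held by a float32 tensor of the given shape."""
    total = 1
    for d in shape:
        total *= d
    return total * bytes_per_el

# Hypothetical sizes, just for illustration.
b, n, k = 1024, 32, 512

# Every broadcasted product in the snippet (v * x, h * v, ...) creates
# a (b, n, k, 1) intermediate that autograd retains for backward.
per_intermediate = tensor_bytes((b, n, k, 1))
print(f"{per_intermediate / 2**20:.0f} MiB per (b, n, k, 1) intermediate")
```

With several such intermediates retained per forward pass, memory grows roughly linearly in `b*n*k`, which may explain the blow-up at `k=512` if `b` or `n` is large.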

Any help would be appreciated.