Hello everyone,

I have a forward pass for my model that works, but for slightly larger inputs it runs out of GPU memory. The core of it multiplies a long sequence of matrices in a numerically stable way, with heavy use of broadcasting. I am probably managing memory poorly, so any ideas on how to make it more memory-efficient would be a great help.

A summary of what the forward should do

Basically, I want to compute the following expression:

p_0 D_0 P_1 D_1 … D_{n-1}P_n x_n

Here every product is a matrix multiplication, and p_0 and x_n are vectors.

D_i is a diagonal matrix whose diagonal entries are the values of x_i (a vector).

P_i is a matrix of conditional probabilities.

The forward first computes all the products D_{i-1}P_i = B_i at once; this is the first definition of middle in the forward function.

Then we have

p_0 B_1 … B_n x_n

The B matrices are then multiplied in pairs by the mulstack function defined inside forward, until a single matrix remains. Finally, that matrix is multiplied by x_n and p_0.
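For reference, the stabilised product of one pair can be written out as below. This is a reading of what mulstack does, not a separate method: a = log A and c = log C denote the two log-space factors of the pair, and b^{(1)}, b^{(2)} are the row and column maxima (all symbol names here are chosen for illustration). The identity is exact without clamping; the code additionally clamps each exponential from below at 1e-10.

```latex
\log (AC)_{ij}
  = b^{(1)}_i + b^{(2)}_j
  + \log \sum_{k} \exp\!\big(a_{ik} - b^{(1)}_i\big)\,\exp\!\big(c_{kj} - b^{(2)}_j\big),
\qquad
b^{(1)}_i = \max_k a_{ik}, \quad b^{(2)}_j = \max_k c_{kj}.
```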

The dimensions of the tensors used are:

X.shape = [n,t,v,c]

self.bulk.shape = [1,t,v-1,c,c]

self.p0.shape = [1,t,c]

where

- n is the batch size (not the n in the expression above),

- t is always equal to 1,

- v is the number of variables (the number of x_i),

- c is a hidden dimension.

```
import torch

n, t, v, c = 10, 1, 6, 40
X = torch.randn(n, t, v, c) - 2       # treated as log-values of x_0 ... x_{v-1}
bulk = torch.rand(1, t, v - 1, c, c)  # P_1 ... P_{v-1}
p0 = torch.rand(1, t, c)


def forward(x: torch.Tensor, bulk, p0, T, C):
    x0 = x[..., :-1, :, None]   # [n, t, v-1, c, 1]
    x1 = x[..., -1, None, :]    # [n, t, 1, c]
    middle = x0 + bulk.log()    # log(D_{i-1} P_i) = log(B_i), [n, t, v-1, c, c]

    def mulstack(stack):
        # Multiply neighbouring pairs of log-space matrices, halving the stack.
        if stack.shape[-3] % 2 == 1:
            # Pad odd-length stacks with one extra matrix. Note: zeros in log
            # space is an all-ones matrix, not the identity.
            pad = torch.zeros(stack.shape[0], T, 1, C, C, device=stack.device)
            stack = torch.cat([stack, pad], dim=-3)
        stck1 = stack[..., 0::2, :, :]
        stck2 = stack[..., 1::2, :, :]
        # Subtract row maxima (left factor) and column maxima (right factor)
        # before exponentiating, then add them back after the product.
        b1 = torch.max(stck1, dim=-1, keepdim=True)[0]
        b2 = torch.max(stck2, dim=-2, keepdim=True)[0]
        stck1 = torch.clamp((stck1 - b1).exp(), 1e-10)
        stck2 = torch.clamp((stck2 - b2).exp(), 1e-10)
        mul = stck1 @ stck2
        return b1 + b2 + mul.log()

    while middle.shape[-3] > 1:
        middle = mulstack(middle)
    middle = torch.logsumexp(middle.squeeze(-3) + x1, dim=-1)  # [n, t, c]
    result = torch.logsumexp(p0.log() + middle, dim=-1)        # [n, t]
    return result


forward(X, bulk, p0, t, c)
```
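One possible lower-memory alternative, sketched here under assumptions rather than as a drop-in replacement: because p_0 is a vector, the whole expression can be evaluated left to right as a chain of vector-matrix products, so the [n,t,v-1,c,c] tensor is never built. The function name forward_sequential is hypothetical. It swaps the pairwise (tree) reduction for a sequential recursion with per-step max rescaling, so it takes v-1 sequential steps, uses no clamp and no padding, and its output can therefore differ slightly from the code above.

```python
import torch


def forward_sequential(x, bulk, p0):
    """Left-to-right evaluation of log(p_0 D_0 P_1 D_1 ... P_{v-1} x_{v-1}).

    x:    [n, t, v, c]      log-values of the x_i
    bulk: [1, t, v-1, c, c] non-negative matrices P_i
    p0:   [1, t, c]         non-negative vector
    Returns a tensor of shape [n, t].
    """
    v = x.shape[-2]
    alpha = p0.log() + x[..., 0, :]                  # log(p_0 D_0), [n, t, c]
    for i in range(v - 1):
        # Rescale so the largest entry is 1 before leaving log space.
        m = alpha.max(dim=-1, keepdim=True).values   # [n, t, 1]
        w = (alpha - m).exp()                        # [n, t, c]
        # Vector-matrix product with P_{i+1}; t acts as a batch dimension.
        w = torch.einsum("ntc,tcd->ntd", w, bulk[0, :, i])
        # Back to log space, then apply D_{i+1} (or the final x vector).
        alpha = w.log() + m + x[..., i + 1, :]
    return torch.logsumexp(alpha, dim=-1)            # [n, t]


if __name__ == "__main__":
    n, t, v, c = 10, 1, 6, 40
    X = torch.randn(n, t, v, c) - 2
    bulk = torch.rand(1, t, v - 1, c, c)
    p0 = torch.rand(1, t, c)
    print(forward_sequential(X, bulk, p0).shape)
```

With this layout the largest intermediates are of size [n,t,c]. If gradients are needed, autograd still keeps one such tensor per step. If memory is still tight, the batch can additionally be processed in chunks (for example with torch.split along dim 0), since the batch elements are independent.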

With n, t, v, c = 500, 1, 40, 200 I get:

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.57 GiB. GPU 0 has a total capacity of 10.75 GiB of which 894.69 MiB is free. Process 1813936 has 1.47 GiB memory in use. Process 1813962 has 5.01 GiB memory in use. Including non-PyTorch memory, this process has 3.40 GiB memory in use. Of the allocated memory 3.21 GiB is allocated by PyTorch, and 18.03 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (CUDA semantics — PyTorch 2.4 documentation)
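For a rough sense of scale, here is a back-of-the-envelope estimate of the tensor sizes involved. It assumes float32 (4 bytes per element, the PyTorch default; the dtype is not stated above), and tensor_gib is a hypothetical helper written just for this estimate.

```python
import math


def tensor_gib(*shape, bytes_per_element=4):
    """Size in GiB of a dense tensor with the given shape (float32 by default)."""
    return math.prod(shape) * bytes_per_element / 2**30


n, t, v, c = 500, 1, 40, 200

# middle holds all v-1 matrices B_i for every batch element.
middle = tensor_gib(n, t, v - 1, c, c)
# After padding to an even count, each of stck1 / stck2 covers half the stack.
half = tensor_gib(n, t, math.ceil((v - 1) / 2), c, c)

print(f"middle:         {middle:.2f} GiB")  # middle:         2.91 GiB
print(f"one half-stack: {half:.2f} GiB")    # one half-stack: 1.49 GiB
```

So middle alone takes about 2.9 GiB, and each elementwise step in mulstack (the subtraction, exp, and clamp) produces another tensor of roughly the half-stack size. With roughly 6.5 GiB of the 10.75 GiB card held by the other two processes, there is little headroom left.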