Memory-efficient multiple matrix multiplication

Hello everyone,

I have a forward pass for my model that works, but for slightly larger input data it runs out of GPU memory. The main part is multiplying a long sequence of matrices in a numerically stable way, and I also use a lot of broadcasting. I am probably doing poor memory management, so if anyone has an idea on how to make it more memory efficient it would be of great help.

A summary of what the forward should do:

Basically, we want to compute the following expression:
p_0 D_0 P_1 D_1 … D_{n-1} P_n x_n

where everything is matrix multiplication, and p_0 and x_n are vectors.

D_i is a diagonal matrix whose diagonal entries are the values of x_i (a vector),

and P_i is a matrix of conditional probabilities.

The forward first computes all the products B_i = D_{i-1} P_i at once, when middle is defined for the first time in the forward function.

Then we have

p_0 B_1 … B_n x_n

and the multiplication of the B matrices is done in couples by the mulstack function defined inside the forward. At the end comes the product with p_0 and x_n.
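
For small inputs, the same quantity can also be evaluated directly in linear space, which is useful as a reference for what the forward is supposed to return. This is only a rough sketch (the helper name naive_forward and the explicit loop are mine, not part of the model); it treats the entries of the x tensor as the logs of the values in the formula, which is how the forward below uses them:

import torch

def naive_forward(x, bulk, p0):
    # rough reference: evaluates p_0 B_1 ... B_{v-1} x_last directly in linear
    # space, so it is only usable for small sizes
    n, t, v, c = x.shape
    acc = p0.unsqueeze(-2)                           # [1, t, 1, c], the row vector p_0
    for i in range(v - 1):
        D = torch.diag_embed(x[..., i, :].exp())     # [n, t, c, c], diag of exp(x[..., i, :])
        acc = acc @ D @ bulk[..., i, :, :]           # accumulate p_0 B_1 ... B_{i+1}
    acc = (acc.squeeze(-2) * x[..., -1, :].exp()).sum(-1)  # dot with the last vector
    return acc.log()                                 # [n, t]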

The dimensions of the tensors used are:
X.shape = [n,t,v,c]

self.bulk.shape = [1,t,v-1,c,c]

self.p0.shape = [1,t,c]

where
- n is the batch size,
- t is always equal to 1,
- v is the number of variables (the number of x_i),
- c is a hidden dimension.

n,t,v,c = 10, 1, 6, 40
X = torch.randn(n,t,v,c) - 2
bulk = torch.rand(1,t,v-1,c,c)
p0 = torch.rand(1, t, c)

def forward(x: torch.Tensor, bulk, p0, T, C):
    # split x into the first v-1 variables and the last one
    x0 = x[..., :-1, :, None]   # [n, t, v-1, c, 1]
    x1 = x[..., -1, None, :]    # [n, t, 1, c]

    # log B_i = log(D_{i-1} P_i): x0 holds the logs of the diagonals of D, bulk holds the P's
    middle = x0 + bulk.log()    # [n, t, v-1, c, c]

    def mulstack(stack):
        # if the number of matrices is odd, pad with a log-space identity so the
        # extra matrix leaves the product unchanged (log-zeros would become an
        # all-ones matrix after exp and would change the result)
        if stack.shape[-3] % 2 == 1:
            eye = torch.eye(C, dtype=stack.dtype, device=stack.device)
            stack = torch.cat([stack, eye.log().expand(stack.shape[0], T, 1, C, C)], dim=-3)

        # pair the matrices up: (B_1, B_2), (B_3, B_4), ...
        stck1 = stack[..., 0::2, :, :]
        stck2 = stack[..., 1::2, :, :]

        # stabilise: shift rows of stck1 and columns of stck2 by their max
        # before leaving log space
        b1 = torch.max(stck1, dim=-1, keepdim=True)[0]
        b2 = torch.max(stck2, dim=-2, keepdim=True)[0]

        stck1 = torch.clamp((stck1 - b1).exp(), 1e-10)
        stck2 = torch.clamp((stck2 - b2).exp(), 1e-10)

        # multiply each pair in linear space, then go back to log space
        mul = (stck1 @ stck2).log()

        return b1 + b2 + mul
    
    # keep multiplying pairs until only the full product is left
    while middle.shape[-3] > 1:
        middle = mulstack(middle)

    # contract with the last variable on the right and with p_0 on the left
    middle = torch.logsumexp(middle.squeeze(-3) + x1, dim=-1)
    result = torch.logsumexp(p0.log() + middle, dim=-1)

    return result


forward(X, bulk, p0, t, c)
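
For these small sizes the output should agree, up to floating point error, with the naive reference sketched earlier:

out = forward(X, bulk, p0, t, c)
ref = naive_forward(X, bulk, p0)   # the rough reference from earlier
print(torch.allclose(out, ref, rtol=1e-4, atol=1e-4))  # expected: True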

With n, t, v, c = 500, 1, 40, 200, however, I get:

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.57 GiB. GPU 0 has a total capacity of 10.75 GiB of which 894.69 MiB is free. Process 1813936 has 1.47 GiB memory in use. Process 1813962 has 5.01 GiB memory in use. Including non-PyTorch memory, this process has 3.40 GiB memory in use. Of the allocated memory 3.21 GiB is allocated by PyTorch, and 18.03 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (CUDA semantics — PyTorch 2.4 documentation)
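
(The allocator setting suggested at the end of the message is easy to try, for example with something like the lines below, set before the first CUDA call, but I suspect the real problem is how much memory my forward allocates.)

import os
# hint from the error message; must be set before CUDA is initialised
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"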


I am getting an error with that implementation, but not with the pad version below. Before I proceed, could you verify my code once?

def forward_einsum_pad(self, x: torch.Tensor):
    middle = x[...,:-1, :, None] + self.bulk.log()

    def mulstack(stack):
        if stack.shape[-3] % 2 == 1:
            # Pad with a log-space identity so the extra matrix leaves the
            # product unchanged (zeros here would be an all-ones matrix after exp)
            padding_shape = list(stack.shape)
            padding_shape[-3] = 1  # a single extra matrix along the stack dimension
            eye = torch.eye(stack.shape[-1], dtype=stack.dtype, device=stack.device)
            stack = torch.cat([stack, eye.log().expand(*padding_shape)], dim=-3)

        b1 = torch.max(stack[...,0::2,:, :], dim=-1, keepdim=True)[0]
        b2 = torch.max(stack[...,1::2,:, :], dim=-2, keepdim=True)[0]
        stck1 = torch.clamp(torch.exp(stack[...,0::2,:, :] - b1), self.epsilon)
        stck2 = torch.clamp(torch.exp(stack[...,1::2,:, :] - b2), self.epsilon)

        mul = torch.log(stck1 @ stck2)
        b = b1 + b2

        return b + mul

    while middle.shape[-3] > 1:
        middle = mulstack(middle)

    middle = torch.logsumexp(middle.squeeze(-3) + x[...,-1, None, :], dim=-1)
    result = torch.logsumexp(self.p0.log() + middle, dim=(-1))

    return result

Link to my minimal testing: Google Colab


I'm not sure, the forward seems right, but looking at the mock class you built I think the bulk and the input have the wrong dimensions. I don't know if that is causing the problem, but x should be (Batchsize, T, Variables, C), the bulk (1, T, Variables-1, C, C) and p0 (1, T, C). This is because we cycle through the variables and split them into couples to do the matrix multiplications, for the whole batch at once.
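
For concreteness, with made-up small sizes I mean something like this (the sizes are just placeholders):

import torch

n, t, v, c = 4, 1, 6, 8                # hypothetical small sizes, only to show the shapes
x    = torch.randn(n, t, v, c)         # (Batchsize, T, Variables, C)
bulk = torch.rand(1, t, v - 1, c, c)   # (1, T, Variables-1, C, C), one P per consecutive pair
p0   = torch.rand(1, t, c)             # (1, T, C)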

Could you possibly share a small end-to-end implementation then?

I edited it in the original post 🙂