Matrix multiplication of big matrices - how to batch to avoid memory issues

I have two matrices, A of size [1000000, 1024] and B of size [50000, 1024], which I want to multiply (A times B transposed) to get a [1000000, 50000] matrix.

For more context: the 1024 dimension holds the features and the other dimension indexes the samples. I want to compute the distance between my generated samples and the training samples.
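Because the goal is a distance, it may help to spell out how the matrix product relates to squared Euclidean distance: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, so the product of A and B.T is the only expensive term. The sketch below is illustrative only. The helper name `squared_distances` and the toy sizes are assumptions, and the identity is checked against `torch.cdist` on small random tensors.

```python
import torch

def squared_distances(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise squared Euclidean distances between rows of a [m, d] and b [n, d]."""
    a_sq = (a * a).sum(dim=1, keepdim=True)  # [m, 1] squared row norms of a
    b_sq = (b * b).sum(dim=1)                # [n] squared row norms of b
    # ||a||^2 + ||b||^2 - 2 a.b broadcast to [m, n]; clamp tiny negatives from rounding
    return (a_sq + b_sq - 2.0 * (a @ b.T)).clamp_min(0.0)

torch.manual_seed(0)
a = torch.randn(6, 16, dtype=torch.float64)  # toy "generated" features
b = torch.randn(4, 16, dtype=torch.float64)  # toy "training" features
d2 = squared_distances(a, b)
print(torch.allclose(d2, torch.cdist(a, b) ** 2))  # True
```

The same chunking that works for the raw product works here, since the norm terms are just a [m, 1] column and a [n] row.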

When computing the product of A and B.T, I get memory allocation errors (on both CPU and GPU it tries to allocate 200 GB!) because A is a huge matrix with a million samples. How can I get around this? Batching is an option, but I’m not sure how to use the bmm function. And is batching actually a solution if I cannot fit the [1000000, 50000] output matrix into memory?
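On the bmm question: `torch.bmm` multiplies a batch of matrices ([b, n, m] times [b, m, p]) and still returns every output element at once, so on its own it does not reduce peak memory. Plain row-chunking of A with `torch.split` and an ordinary matrix product is enough, because each chunk of rows of A produces the matching rows of A @ B.T. A small check, with toy sizes chosen arbitrarily:

```python
import torch

torch.manual_seed(0)
A = torch.randn(10, 8)  # toy stand-in for the [1000000, 1024] matrix
B = torch.randn(5, 8)   # toy stand-in for the [50000, 1024] matrix

full = A @ B.T  # [10, 5]; only feasible here because the sizes are tiny

# Each chunk of 4 rows of A yields the matching 4 rows of A @ B.T
# (the last chunk holds the remaining 2 rows).
row_blocks = [a_chunk @ B.T for a_chunk in torch.split(A, 4)]

print(torch.allclose(torch.cat(row_blocks), full))  # True
```

Concatenating every block would of course recreate the full 200 GB matrix; batching only helps if each block is reduced (for example to its top-k values) before the next one is computed.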

Hi @Alicja_K,

A quick calculation at torch.float32 precision confirms that you do indeed need about 200 GB:

```python
import torch

x = torch.randn(1, dtype=torch.float32)
total_memory = x.element_size() * (1000000 * 50000)  # 4 bytes per element
print(f"{total_memory:4.2e}")  # 2.00e+11 bytes (200 GB)
```

The only way to compute it in one go (at least as far as I can tell) is to physically have 200 GB of RAM. You could batch the operation to generate one subset of this (1e6 x 5e4) matrix at a time, but then you’d only have that subset available at any moment to compute distances between your generated samples and training samples.

Do you have a minimal reproducible example of what you want to do exactly?

Also, if you’re going to perform further matrix operations on this (1e6 x 5e4) matrix, you may be able to exploit the associativity of matrix multiplication to avoid materializing the whole matrix while still using all of its information. PyTorch uses this idea in torch.linalg.multi_dot, which reorders a chain of matrix products so that the fewest arithmetic operations are performed; see the torch.linalg.multi_dot page of the PyTorch 2.0 documentation.
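As an example of the associativity idea (the weight vector `v` and the toy sizes are hypothetical, not from the question): the weighted row sums (A @ B.T) @ v equal A @ (B.T @ v), and the right-hand side never forms the m x n matrix. `torch.linalg.multi_dot` accepts a 1D tensor as its last argument and picks the cheap order automatically:

```python
import torch

torch.manual_seed(0)
A = torch.randn(100, 32, dtype=torch.float64)  # toy generated features
B = torch.randn(60, 32, dtype=torch.float64)   # toy training features
v = torch.randn(60, dtype=torch.float64)       # hypothetical per-training-sample weights

naive = (A @ B.T) @ v                       # materializes the [100, 60] matrix first
reordered = A @ (B.T @ v)                   # only creates a [32] and a [100] vector
auto = torch.linalg.multi_dot([A, B.T, v])  # multiplication order chosen by multi_dot

print(torch.allclose(naive, reordered), torch.allclose(naive, auto))  # True True
```

This only works for linear reductions such as sums or weighted sums. A per-row max or topk cannot be pushed inside the product this way, which is why the top-k case needs chunking instead.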

Yes, so even float16 is not enough (it would still need 100 GB). Thanks for confirming that this huge number is real and not some buggy allocation.
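The float16 figure follows from the same arithmetic with 2 bytes per element instead of 4. A small sketch (the helper name `full_matrix_gb` is hypothetical; GB here means 1e9 bytes):

```python
import torch

N_GENERATED, N_TRAINING = 1_000_000, 50_000

def full_matrix_gb(dtype: torch.dtype) -> float:
    """Size in GB (1e9 bytes) of a dense [N_GENERATED, N_TRAINING] matrix of `dtype`."""
    itemsize = torch.empty((), dtype=dtype).element_size()  # bytes per element
    return itemsize * N_GENERATED * N_TRAINING / 1e9

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(f"{dtype}: {full_matrix_gb(dtype):.0f} GB")
# torch.float32: 200 GB
# torch.float16: 100 GB
# torch.bfloat16: 100 GB
```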

What I mostly use is the top-1 (maximum) value of the distance (here really a similarity) per sample, and from there I compute statistics on the distribution of the closest samples. I also use the top 3 per sample to display examples of similar pairs. So this seems manageable: I don’t need the whole matrix in the end and can just keep track of the top-3 values.

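I guess it would look roughly like this:

```python
import torch

# a_features: [1000000, 1024] tensor, b: [50000, 1024] tensor
simscores = []
for a_batch in torch.split(a_features, 10000):  # a_batch: [10000, 1024]
    sim = a_batch @ b.T  # sim: [10000, 50000]
    # topk returns (values, indices), each of shape [10000, 3]
    simscores.append(sim.topk(3, dim=1, largest=True))

# concatenate the per-batch results into [1000000, 3] values and indices
top3_values = torch.cat([s.values for s in simscores])
top3_indices = torch.cat([s.indices for s in simscores])
```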

Any feedback on this?
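A self-contained version of that loop can be checked against the full computation on toy data. This is a sketch under assumptions: the function name `chunked_topk`, the default chunk size, and the toy sizes are illustrative, and the similarity is the plain dot product used above.

```python
import torch

def chunked_topk(a: torch.Tensor, b: torch.Tensor, k: int = 3, chunk_rows: int = 10_000):
    """Top-k dot-product similarities of each row of `a` against all rows of `b`,
    computed chunk by chunk so at most a [chunk_rows, len(b)] block exists at once."""
    values, indices = [], []
    with torch.no_grad():
        for a_chunk in torch.split(a, chunk_rows):
            sim = a_chunk @ b.T                     # [chunk_rows, len(b)]
            top = sim.topk(k, dim=1, largest=True)  # values and indices, each [chunk_rows, k]
            values.append(top.values)
            indices.append(top.indices)             # already indices into b, no offset needed
    return torch.cat(values), torch.cat(indices)

torch.manual_seed(0)
a = torch.randn(25, 8)  # toy generated features
b = torch.randn(12, 8)  # toy training features
vals, idx = chunked_topk(a, b, k=3, chunk_rows=10)
print(vals.shape, idx.shape)  # torch.Size([25, 3]) torch.Size([25, 3])
```

Note that topk gives a (values, indices) pair of [N, 3] tensors rather than one [N, 3, 2] tensor, so the stored result for 1M samples is only a few tens of MB.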

Hi @Alicja_K,

That could mitigate the memory problem: each sim block of [10000, 50000] float32 values needs only 2 GB of RAM, so you’d simply trade RAM for wall time.
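The 2 GB figure is 10000 x 50000 x 4 bytes. Inverting it gives a chunk size for a given memory budget. A minimal sketch (the helper name and the 2e9-byte budget are hypothetical); it ignores the topk outputs and other temporaries, so in practice it is wise to leave some headroom:

```python
import torch

def rows_per_chunk(budget_bytes: float, n_training: int, dtype=torch.float32) -> int:
    """Largest number of generated-sample rows whose [rows, n_training] similarity
    block fits in `budget_bytes` (ignores topk outputs and other temporaries)."""
    itemsize = torch.empty((), dtype=dtype).element_size()  # bytes per element
    return int(budget_bytes // (itemsize * n_training))

print(rows_per_chunk(2e9, 50_000))                 # 10000
print(rows_per_chunk(2e9, 50_000, torch.float16))  # 20000
```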

Also, why do your generated-sample and training-sample sets have different sizes? If they were the same size (and each generated sample were paired with the training example it predicts), you could compute a simple row-wise cosine similarity, which is essentially a magnitude-normalized dot product and doesn’t require a huge matrix. But again, I don’t know your specific problem; this is just an educated guess.
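Even without pairing, the "magnitude-normalized dot product" view carries over to the all-pairs case: normalizing both feature sets once with `torch.nn.functional.normalize` turns every chunked product into a block of cosine similarities. A toy check against `torch.nn.functional.cosine_similarity` (sizes are arbitrary; for the paired case, `F.cosine_similarity(x, y, dim=1)` on two same-sized [N, d] tensors returns one value per pair):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(7, 16, dtype=torch.float64)  # toy generated features
b = torch.randn(5, 16, dtype=torch.float64)  # toy training features

a_unit = F.normalize(a, dim=1)  # each row scaled to unit L2 norm
b_unit = F.normalize(b, dim=1)
cos_matrix = a_unit @ b_unit.T  # [7, 5] pairwise cosine similarities

# Reference via broadcasting: [7, 1, 16] vs [1, 5, 16] -> [7, 5]
ref = F.cosine_similarity(a[:, None, :], b[None, :, :], dim=-1)
print(torch.allclose(cos_matrix, ref))  # True
```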

Simply because there is no “prediction from” the training data. I’m working with images generated by a noise-to-image diffusion model and testing how close 1 million generated images are to the training set. The second dimension (1024) of each matrix holds the features extracted from each image, so unfortunately I need to compare every generated image to every training image.
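If the generated-image features arrive in batches (for example from a DataLoader over the extracted features) rather than as one in-memory tensor, the same all-pairs top-k can be streamed. This is a sketch under assumptions: the function name `nearest_training`, the use of cosine similarity, and the `device` argument are illustrative choices, not something stated in the thread. The training features (about 200 MB in float32 for [50000, 1024]) are moved to the device once, and only the small top-k results are kept on the CPU.

```python
import torch
import torch.nn.functional as F

def nearest_training(gen_batches, train_feats: torch.Tensor, k: int = 3, device: str = "cpu"):
    """Cosine top-k against all training features for generated features that
    arrive in batches. Returns CPU tensors of shape [N_generated, k]."""
    train_unit = F.normalize(train_feats.to(device), dim=1)  # moved to device once
    values, indices = [], []
    with torch.no_grad():
        for batch in gen_batches:
            gen_unit = F.normalize(batch.to(device), dim=1)
            top = (gen_unit @ train_unit.T).topk(k, dim=1)  # [batch, k] each
            values.append(top.values.cpu())
            indices.append(top.indices.cpu())
    return torch.cat(values), torch.cat(indices)

torch.manual_seed(0)
train = torch.randn(20, 8)                           # toy training features
gen_batches = [torch.randn(6, 8) for _ in range(3)]  # stand-in for a DataLoader
vals, idx = nearest_training(gen_batches, train, k=3)
top1 = vals[:, 0]  # best similarity per generated sample, e.g. for a histogram
print(vals.shape, idx.shape)  # torch.Size([18, 3]) torch.Size([18, 3])
```

On a GPU one would pass `device="cuda"`; the batch size then plays the role of the chunk size discussed above.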