Hello there,

I have a lot of linear layers (up to 12000) with variable input sizes (1-100) and a single output each. I currently see a few ways to tackle this:

- Loop through each layer (low memory consumption, slow)
- Merge them into one big layer by applying `torch.block_diag` to all the weight matrices (high memory consumption, fast inference, but slow when calculating gradients)
- Find a middle ground and combine only some of the layers
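To illustrate the second option, here is a minimal sketch of the merge with tiny stand-in sizes (the real case would have ~12000 layers); the layer sizes and input are made up for the example:

```python
import torch

# Hypothetical sizes for illustration; the real case has up to 12000
# layers with input sizes between 1 and 100
in_sizes = [3, 5, 2]
layers = [torch.nn.Linear(n, 1) for n in in_sizes]

# Merge all weight matrices (each of shape 1 x n_i) into one
# block-diagonal matrix of shape (num_layers, sum(n_i))
W = torch.block_diag(*[l.weight for l in layers])   # shape (3, 10)
b = torch.cat([l.bias for l in layers])             # shape (3,)

# One big matmul replaces the per-layer loop: concatenate all inputs
# and get one output per original layer
x = torch.cat([torch.randn(n) for n in in_sizes])
y = W @ x + b                                       # shape (3,)
```

This reproduces the per-layer outputs in a single matmul, at the cost of materialising all the zero blocks.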

The problem with the second option is that I can't directly create a sparse matrix, so I first have to build this big (> 100000x100000) matrix that is mostly zeros.
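For concreteness, this is the workaround being described, again with tiny stand-in sizes: the dense block-diagonal matrix is built first and only then converted to a sparse COO tensor, so the dense intermediate is the memory bottleneck.

```python
import torch

# Tiny stand-in sizes; the real dense intermediate would be
# ~12000 x 100000+, almost entirely zeros
in_sizes = [3, 5, 2]
weights = [torch.randn(1, n) for n in in_sizes]

# Build the dense block-diagonal matrix first (mostly zeros)...
W_dense = torch.block_diag(*weights)    # shape (3, 10)

# ...then convert; only now do the zeros stop taking memory
W_sparse = W_dense.to_sparse()
```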

Does anyone here know of a more efficient/better solution for this problem?

Thanks