Hello there,
I have many linear layers (up to 12000), each with a different input size (between 1 and 100 features) and a single output. I currently see two ways to tackle this, plus a compromise between them:
- Loop through each layer (low memory consumption, slow)
- Merge them into one big layer by applying `torch.block_diag` to all weight matrices (high memory consumption; fast inference, but slow when calculating gradients)
- Find a middle ground and combine only some of the layers
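A minimal sketch of the first two options, assuming each layer is a separate `nn.Linear(n_i, 1)` and the inputs arrive as one tensor per layer. The sizes in `in_sizes` are made up for illustration. Both paths should give the same `(batch, num_layers)` output, but the block-diagonal weight `W` is a dense `(num_layers, sum(n_i))` matrix that is almost entirely zeros, which is where the memory goes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical setup: a few layers with input sizes between 1 and 100
# (the real problem has up to 12000 of them).
in_sizes = [3, 1, 7, 100, 42]
layers = nn.ModuleList(nn.Linear(n, 1) for n in in_sizes)

# One input tensor per layer: inputs[i] has shape (batch, in_sizes[i]).
batch = 4
inputs = [torch.randn(batch, n) for n in in_sizes]

# Option 1: loop over the layers (low memory, many small kernel launches).
out_loop = torch.cat([layer(x) for layer, x in zip(layers, inputs)], dim=1)

# Option 2: one big block-diagonal layer.
# Each weight is (1, n_i), so W is a dense (num_layers, sum(n_i)) matrix.
W = torch.block_diag(*[layer.weight for layer in layers])
b = torch.cat([layer.bias for layer in layers])
x_cat = torch.cat(inputs, dim=1)  # (batch, sum(n_i))
out_block = x_cat @ W.T + b       # (batch, num_layers)
```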
The problem with the second approach is that I can't create the sparse matrix directly, so I first have to build a huge dense matrix (> 100000x100000) that is mostly zeros.
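Because every layer has exactly one output, each layer is just a dot product, so one possible way to avoid the dense matrix altogether is to store all weights in one flat vector and sum each layer's products with `Tensor.index_add`. This is a different technique from `torch.block_diag`, not something from the post; `PackedLinears` and its simple weight initialization are hypothetical. Memory scales with `sum(n_i)` instead of `num_layers * sum(n_i)`, and autograd handles the backward pass.

```python
import torch
import torch.nn as nn


class PackedLinears(nn.Module):
    """Hypothetical module: many Linear(n_i, 1) layers stored as one flat weight vector.

    Layer i owns weight entries [offset_i, offset_i + n_i), so its output is the
    sum of x[:, offset_i:offset_i + n_i] * weight[offset_i:offset_i + n_i].
    """

    def __init__(self, in_sizes):
        super().__init__()
        total = sum(in_sizes)
        # Simple illustrative init (nn.Linear uses a different default).
        self.weight = nn.Parameter(torch.randn(total) / 10)
        self.bias = nn.Parameter(torch.zeros(len(in_sizes)))
        # segment_ids[j] = index of the layer that input feature j belongs to.
        seg = torch.repeat_interleave(torch.arange(len(in_sizes)), torch.tensor(in_sizes))
        self.register_buffer("segment_ids", seg)

    def forward(self, x):
        # x: (batch, sum(n_i)), all per-layer inputs concatenated along dim 1.
        prod = x * self.weight  # elementwise, (batch, sum(n_i))
        out = x.new_zeros((x.shape[0], self.bias.shape[0]))
        # Accumulate each feature's product into its layer's output column.
        out = out.index_add(1, self.segment_ids, prod)
        return out + self.bias  # (batch, num_layers)
```

Usage would look like `PackedLinears(in_sizes)(torch.cat(inputs, dim=1))`, where `inputs` is the list of per-layer input tensors.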
Does anyone here know of a more efficient/better solution for this problem?
Thanks