I’m developing TorchUncertainty, an open-source library that makes state-of-the-art methods for improving the reliability of modern deep neural networks easier to use.
For this project, I am trying to optimize the memory usage and the inference/training speed of a particular case of torch.nn.Linear with block diagonal weights, but despite numerous attempts I can’t get a stable speedup over a for loop of smaller nn.Linear layers. The solution should work for both inference and training, and using torch 2.0 is fine. My experiments mainly use AMP with fp16.
Let me detail the problem. In the simplest case, a dataloader provides samples of shape (B, M × F), where B is the batch size (typically 128), F is the number of features (for instance, a power of 2 between 64 and 4096), and M is an integer between 1 and 16.
What makes this case special is that I want the weight passed to F.linear to be block diagonal, with M blocks of the same size, each of height F.
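To make the setup concrete, here is a minimal sketch of what such a layer computes, assuming square F × F blocks (so the layer maps (B, M × F) to (B, M × F)). F is renamed n_feat to avoid clashing with torch.nn.functional imported as F; the sizes and variable names are illustrative only. It checks that the dense block diagonal weight gives the same result as applying each block to its own input slice, and shows that the dense weight stores M times more values than the blocks themselves.

```python
import torch
import torch.nn.functional as F

# Assumption: square n_feat x n_feat blocks; sizes are illustrative.
B, M, n_feat = 128, 4, 64

blocks = torch.randn(M, n_feat, n_feat)  # block i is an (out, in) weight
x = torch.randn(B, M * n_feat)

# Dense view: one (M*F, M*F) weight that is zero outside the M diagonal blocks.
dense_weight = torch.block_diag(*blocks)
y_dense = F.linear(x, dense_weight)

# Per-block view: output slice i depends only on input slice i and block i.
x_slices = x.split(n_feat, dim=1)
y_blocks = torch.cat([F.linear(x_slices[i], blocks[i]) for i in range(M)], dim=1)

assert torch.allclose(y_dense, y_blocks, atol=1e-4)

# The dense weight holds M times more values than the blocks.
print(dense_weight.numel(), blocks.numel())  # 65536 16384
```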
I have tried several solutions:
- masking: I can mask the off-block-diagonal weights, but then I don’t benefit from the structured sparsity: the full dense matrix is still stored and multiplied, which is costly in memory and time.
- for loop: I can apply the linear layers in a for loop on slices of the input matrix. This works well and seems memory efficient, but it is slow. I tried to vmap the loop and got a small speedup, but I haven’t been able to make it work in training.
- torch.block_diag: I can store the M blocks as parameters in a 3-D tensor and assemble them into a block diagonal matrix in .forward(). This uses more memory and is about as slow as the for loop (although the speed depends a lot on the matrix size).
- grouped conv1d: I can use a grouped conv1d with M groups after a rearrange and unsqueeze(-1). Again, this is memory efficient but not faster (and again, the speed depends a lot on the matrix size).
- BSR: Finally, I also tried BSR sparse weights, which F.linear now supports, but I couldn’t make them work during training. I’m also not sure they would be very efficient here: the density is only 1/M, so the sparsity is modest, even though the block diagonal layout is very regular and simple.
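For reference, the dense approaches listed above can be sketched as plain functions that share one parameter layout. This is a minimal correctness sketch, not a benchmark, and it assumes square blocks stored as a (M, F, F) tensor plus a bias of size M × F; the function names are hypothetical. With the (B, M × F) input already channel-major, unsqueeze(-1) alone is enough for the grouped conv1d in this sketch.

```python
import torch
import torch.nn.functional as F

# Assumption: blocks has shape (M, n_feat, n_feat), bias has shape (M * n_feat,).
# Function names are hypothetical and only illustrate the approaches.


def masked_linear(x, dense_weight, mask, bias):
    # Masking: a full (M*F, M*F) weight whose off-diagonal blocks are zeroed.
    return F.linear(x, dense_weight * mask, bias)


def loop_linear(x, blocks, bias):
    # For loop: one small linear per input slice, then concatenate.
    n_feat = blocks.shape[1]
    outs = [F.linear(xi, wi) for xi, wi in zip(x.split(n_feat, dim=1), blocks)]
    return torch.cat(outs, dim=1) + bias


def block_diag_linear(x, blocks, bias):
    # torch.block_diag: materialize the dense block diagonal weight in forward.
    return F.linear(x, torch.block_diag(*blocks), bias)


def grouped_conv1d_linear(x, blocks, bias):
    # Grouped conv1d: treat the M*F features as channels with length 1.
    M, n_feat, _ = blocks.shape
    weight = blocks.reshape(M * n_feat, n_feat, 1)  # (out, in / groups, kernel)
    return F.conv1d(x.unsqueeze(-1), weight, bias, groups=M).squeeze(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    B, M, n_feat = 8, 4, 16
    x = torch.randn(B, M * n_feat)
    blocks = torch.randn(M, n_feat, n_feat)
    bias = torch.randn(M * n_feat)

    mask = torch.block_diag(*torch.ones(M, n_feat, n_feat))
    # Off-diagonal noise that the mask must remove.
    dense = torch.block_diag(*blocks) + torch.randn(M * n_feat, M * n_feat) * (1 - mask)

    ref = block_diag_linear(x, blocks, bias)
    for fn, out in [
        ("masked", masked_linear(x, dense, mask, bias)),
        ("loop", loop_linear(x, blocks, bias)),
        ("conv1d", grouped_conv1d_linear(x, blocks, bias)),
    ]:
        print(fn, torch.allclose(out, ref, atol=1e-4))
```

All four functions are built from differentiable PyTorch ops, so autograd can backpropagate through each of them. Any speed comparison would still need to be measured separately under AMP fp16.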
I can share my small benchmark and code if needed.
Do you have any other ideas? Thanks so much in advance!