Block matrix multiplication (memory error)


I have a block diagonal matrix A = [A_1, 0, 0; 0, A_2, 0; 0, 0, A_3]. I multiply it with my input vector X = [X_1; X_2; X_3] to get the output Y = [Y_1; Y_2; Y_3]. While training my neural net, it seems that during the backward pass PyTorch tries to allocate a huge amount of memory and throws the error:

"RuntimeError: CUDA out of memory. Tried to allocate 2875.53 GiB (GPU 0; 3.94 GB total capacity; 1.31 GiB already allocated; 1.68 GiB free; 1.44 GiB reserved in total by PyTorch)"

However, if I use a for loop, I do not hit this problem, but the operation is very slow. I think this happens because during the backward pass PyTorch is trying to compute d(Y_1)/d(A_1), d(Y_1)/d(A_2), d(Y_1)/d(A_3), even though Y_1 depends only on A_1. Is there any way to solve this problem?
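
To make the setup concrete, here is a minimal sketch of one loop-free formulation. It assumes all blocks have the same shape, so they can be stacked into one 3-D tensor and multiplied with torch.bmm. The function name block_diag_matmul and the sizes (3 blocks of shape 4 x 5) are made up for illustration and are not my real model. In this form each Y_i is computed only from A_i and X_i, so autograd stores gradients for the blocks alone and never for the zero entries of the dense A. Whether this avoids the allocation in the error above is not something the sketch proves.

```python
import torch


def block_diag_matmul(blocks: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Compute Y = A @ X for a block diagonal A without building A.

    blocks: (n_blocks, m, k) tensor holding A_1 ... A_n (assumed equal shapes)
    x:      (n_blocks * k,) vector holding X_1 ... X_n stacked
    returns (n_blocks * m,) vector holding Y_1 ... Y_n stacked
    """
    n_blocks, m, k = blocks.shape
    # Split X into one column vector per block: (n_blocks, k, 1)
    x_chunks = x.reshape(n_blocks, k, 1)
    # Batched product: Y_i = A_i @ X_i for every i in one call
    y_chunks = torch.bmm(blocks, x_chunks)  # (n_blocks, m, 1)
    return y_chunks.reshape(n_blocks * m)


torch.manual_seed(0)
blocks = torch.randn(3, 4, 5, requires_grad=True)  # hypothetical sizes
x = torch.randn(15)

y = block_diag_matmul(blocks, x)
y.sum().backward()  # gradient has the shape of the blocks only

# Dense reference, built only to check the result on this small example
dense = torch.block_diag(*blocks.detach())  # (12, 15)
y_ref = dense @ x
```

If the blocks have different shapes, torch.bmm does not apply directly; padding the blocks to a common shape or keeping the loop for the odd-sized ones would be needed.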

Any help is really appreciated!

Thank you!