Accessing all matrix multiplications done by PyTorch

Perhaps a vague and ambitious question, but for a research project I am looking to capture every matrix multiplication that a PyTorch model performs under the hood and store it in memory. The use case is to run my own simulations with these matrices.

Ideally, after running a particular PyTorch model (in both training and inference), I would be able to generate a Python list of all the matrix multiplications it performed, in the order in which they occurred. How would I begin to go about doing something like this? Any suggestions would be much appreciated!

By accessing all matrix multiplications, do you mean storing the outputs of these matmuls, or capturing the calls themselves?

Both the inputs and outputs, and the sequence in which they are multiplied (if possible). This would be sufficient for now.

To motivate this question: in the long run I will be attempting to perform all matrix multiplications on a processor that PyTorch currently cannot target. So what I’d like to do is take any two matrices that PyTorch multiplies, multiply them myself (rather than on the GPU, as PyTorch would usually do), and route the result back into the model for further training/inference. So I suppose that for this purpose, access to the calls would be necessary.

Any direction or advice would be greatly appreciated. Thank you for your help!

Maybe overriding __torch_function__ as described here could work.
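
To make that suggestion a bit more concrete, here is a minimal sketch of one way it might look, using `torch.overrides.TorchFunctionMode` (available in recent PyTorch releases) so that `__torch_function__` is applied to every call inside a `with` block. The class name `MatmulRecorder` and the `MATMUL_NAMES` set are my own assumptions for illustration, and the set is not exhaustive (convolutions, `einsum`, attention kernels and so on are not covered).

```python
import torch
from torch.overrides import TorchFunctionMode

# Assumption: names of Python-level torch functions that we choose to treat
# as matrix multiplications. This list is illustrative, not exhaustive.
MATMUL_NAMES = {
    "matmul", "__matmul__", "__rmatmul__",
    "mm", "bmm", "mv", "addmm", "baddbmm", "linear",
}


class MatmulRecorder(TorchFunctionMode):
    """Hypothetical helper: records inputs and output of matching calls, in order."""

    def __init__(self):
        super().__init__()
        self.records = []  # list of (function name, [input tensors], output tensor)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the original operation. To reroute the computation to another
        # processor, a custom result could be computed and returned here instead.
        out = func(*args, **kwargs)
        name = getattr(func, "__name__", None)
        if name in MATMUL_NAMES:
            # Only positional tensor arguments are stored in this sketch.
            inputs = [a.detach().clone() for a in args if isinstance(a, torch.Tensor)]
            self.records.append((name, inputs, out.detach().clone()))
        return out


model = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)

with MatmulRecorder() as recorder:
    y = model(x)   # nn.Linear calls torch.nn.functional.linear internally
    z = x @ x.T    # explicit matrix product

for name, inputs, output in recorder.records:
    print(name, [tuple(t.shape) for t in inputs], tuple(output.shape))
```

One caveat, as far as I understand it: `__torch_function__` only sees calls made through the Python API, so the multiplications that autograd runs during the backward pass would not show up in this list. Capturing those would likely need a hook at a lower level than this sketch.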
