Forward pass hook to get weights multiplied by input?

Given an arbitrary model, I’d like to get a matrix of weights * inputs. I do not want the activations/outputs. I want a matrix of values where each entry is the weight of one connection multiplied by the input it connects to, i.e. the weight from input j to output i times the jth input: w_ij * input_j
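For an nn.Linear layer, write W_{ij} for the weight that connects input j to output i (weight shape out_features × in_features), b for the bias and x for the input. The requested matrix C (a name used here only for illustration) relates to the layer's usual output y as follows:

```latex
C_{ij} = W_{ij}\, x_j, \qquad y_i = \sum_{j} C_{ij} + b_i = (W x)_i + b_i
```

The activations the question wants to avoid are the y_i. The wanted values are the C_{ij}, taken before the sum over j.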

How can I make a forward hook to do this?

Could you post a code snippet (with nested for loops if needed) showing the desired output using a plain nn.Linear module and a random input?
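Here is a minimal nested-loop sketch of the desired output for a plain nn.Linear and a random input. The names linear, contrib, W and b are illustrative, and the seed only makes runs repeatable:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

linear = nn.Linear(5, 2)
x = torch.randn(1, 5)  # one sample with 5 features

W = linear.weight.detach()  # (out_features=2, in_features=5)
b = linear.bias.detach()    # (2,)

# contrib[i, j] = weight of the connection from input j to output i, times input j
contrib = torch.empty(W.shape[0], W.shape[1])
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        contrib[i, j] = W[i, j] * x[0, j]

# Broadcasting gives the same matrix without loops: (2, 5) * (5,) -> (2, 5)
print(torch.equal(contrib, W * x[0]))  # True

# Summing each row (the accumulation) and adding the bias recovers the layer output
out = linear(x).detach()
print(torch.allclose(contrib.sum(dim=1) + b, out[0], atol=1e-5))  # True
```

The loops only make the indexing explicit. The broadcast expression produces the identical matrix.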

I’d like something like this:

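```python
import torch
import torch.nn as nn
activation = {}
def get_activation(name, current_linear_module):
    def hook(model, input, output):
        # forward hooks receive the module's positional inputs as a tuple
        activation[name] = current_linear_module.weight * input[0]
    return hook
```

It should work for a linear layer:

```python
current_linear_module = nn.Linear(5, 2)
input = torch.randn(1, 5)
current_linear_module.register_forward_hook(get_activation("linear", current_linear_module))
current_linear_module(input)  # activation["linear"]: (2, 5) * (1, 5) -> (2, 5)
```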
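The hook above can be generalized, assuming every layer of interest is an nn.Linear that is actually called as a module. The hook below reads the layer from the `module` argument it receives, so the layer does not need to be passed in. It is registered on every nn.Linear that `named_modules()` finds. Broadcasting over `x.unsqueeze(-2)` also handles extra leading dimensions, such as (batch, seq, features) inputs to Transformer projections. The helper name `save_contributions` is hypothetical. Operations done through functional calls are invisible to such hooks. As far as I know, nn.MultiheadAttention applies its projection weights that way, so a hook on its `out_proj` submodule may never fire.

```python
import torch
import torch.nn as nn

contributions = {}

def save_contributions(name):  # hypothetical helper
    # Forward hooks are called as hook(module, inputs, output); inputs is a tuple
    def hook(module, inputs, output):
        x = inputs[0]  # (..., in_features)
        # (..., 1, in) * (out, in) -> (..., out, in):
        # entry [..., i, j] = weight[i, j] * x[..., j], the terms before the sum
        contributions[name] = (x.unsqueeze(-2) * module.weight).detach()
    return hook

model = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 2))
handles = [
    m.register_forward_hook(save_contributions(name))
    for name, m in model.named_modules()
    if isinstance(m, nn.Linear)
]

x = torch.randn(3, 5)
with torch.no_grad():
    out = model(x)

for h in handles:
    h.remove()  # remove the hooks once the values are collected

print(sorted(contributions))     # ['0', '2']
print(contributions["2"].shape)  # torch.Size([3, 2, 4])
# Summing over the input dimension and adding the bias reproduces the layer output
print(torch.allclose(contributions["2"].sum(-1) + model[2].bias, out, atol=1e-5))  # True
```

Each stored tensor has batch × out_features × in_features elements, so memory grows quickly for wide layers.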

But how do I get this to work with an arbitrary module, like a CNN or a Transformer? They’re still implemented using matrix multiplications, so how do I grab the intermediate product of each individual connection weight and its input scalar?

For the CNN, in the worst case I’d need to do something like this: "Is there an function in PyTorch for converting convolutions to fully-connected networks form?" (Stack Overflow). However, it would be impossible for me to implement this for many different PyTorch modules. PyTorch should already implement all of this as matrix multiplications.
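For Conv2d, one way to avoid building an explicit fully-connected matrix is im2col via `torch.nn.functional.unfold`. Each sliding patch becomes a column, after which the convolution is a matrix product with the flattened weight, and the per-connection products can be kept before the sum. This sketch assumes groups=1 and numeric padding (not `padding='same'`). It recomputes the values and is not necessarily how PyTorch's own conv kernels compute them:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 10, 10)

# unfold extracts every sliding 3x3 patch as a column:
# (N, C_in * kh * kw, L), where L is the number of output positions
patches = F.unfold(x, kernel_size=conv.kernel_size, padding=conv.padding,
                   stride=conv.stride, dilation=conv.dilation)
w = conv.weight.flatten(1)  # (C_out, C_in * kh * kw), same ordering as unfold

# contrib[n, o, k, l] = w[o, k] * patches[n, k, l]: one term per connection, before the sum
contrib = w[None, :, :, None] * patches[:, None, :, :]  # (N, C_out, C_in*kh*kw, L)

# Accumulating over k and adding the bias recovers the convolution output
recon = contrib.sum(2) + conv.bias[None, :, None]        # (N, C_out, L)
out = conv(x).flatten(2)                                 # (N, C_out, H_out * W_out)
print(contrib.shape)                                     # torch.Size([2, 8, 27, 100])
print(torch.allclose(recon, out, atol=1e-4))             # True
```

`contrib` holds N × C_out × (C_in·kh·kw) × L values, so for realistic layers it can be far larger than the layer's activations.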

I still don’t fully understand the use case, as it seems you want to recompute the output value without the bias. Wouldn’t subtracting the bias tensor be easier in this case?
If you want to manually recompute the output with a matrix multiplication for some reason, you might indeed need to reimplement this approach for every module, since, e.g., conv layers do not always use matmuls, depending on the availability of third-party libraries and kernels.

I need the intermediate elementwise multiplication: weight * value. I don’t want the output, which is the dot product between each row of the weight matrix and the input vector: W*v.

To clarify: each output is computed by first multiplying elementwise and then accumulating the products into a single scalar. Is it possible to get the result before the accumulation?
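In tensor terms, the accumulation is the summed index of the contraction. This small sketch uses `torch.einsum` and keeps that index in the output, so nothing is summed. W, b and x are illustrative random tensors standing in for a linear layer's weight, bias and input:

```python
import torch
import torch.nn.functional as F

W = torch.randn(2, 5)   # (out_features, in_features)
b = torch.randn(2)
x = torch.randn(3, 5)   # (batch, in_features)

# F.linear computes y[n, o] = sum_i x[n, i] * W[o, i] + b[o].
# Keeping the index i in the output skips the accumulation step:
pre_sum = torch.einsum("ni,oi->noi", x, W)  # (3, 2, 5)

# Accumulating afterwards gives the usual result
y = F.linear(x, W, b)
print(pre_sum.shape)                                      # torch.Size([3, 2, 5])
print(torch.allclose(pre_sum.sum(-1) + b, y, atol=1e-5))  # True
```

This only applies to values you recompute yourself. A module's own forward returns just the accumulated result.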