Initialization of a layer with a different function than its forward function

I have a layer in my network; suppose it is a matrix A that takes an input x and produces y = Ax.
Now, the matrix A is just a placeholder: it is computed from the input x by a function f (i.e. A = f(x)), so y = Ax = f(x) * x. The trainable parameters are the parameters inside f, not A.
Implementing such a network is quite easy. The problem is that I want to initialize the placeholder A in another fashion, meaning that A = g(x) at initialization, where g is a different function from f.
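
For concreteness, here is a minimal sketch of the layer as I currently have it. The dimensions and the small MLP used for f are just placeholders for illustration, not my actual model:

```python
import torch
import torch.nn as nn

class PlaceholderMatrixLayer(nn.Module):
    """Computes y = A x, where A = f(x); only f holds trainable parameters."""

    def __init__(self, dim, hidden=32):
        super().__init__()
        # f maps the input x to a flattened (dim x dim) matrix A
        self.f = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim * dim),
        )

    def forward(self, x):  # x: (batch, dim)
        d = x.size(1)
        A = self.f(x).view(-1, d, d)                          # A = f(x)
        return torch.matmul(A, x.unsqueeze(-1)).squeeze(-1)   # y = A x
```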

I’m not sure how to properly implement this. Using a flag for initialization feels hacky and doesn’t seem like true initialization—it just switches between two computations. Additionally, I’m concerned that this kind of network might not be fully end-to-end differentiable.

Does this make sense, and is it possible to implement this in a principled way in PyTorch? Let me know if anything is unclear or if you have questions about my explanation.

This statement is a bit unclear, since you are saying the layer contains the matrix A, which sounds as if A were a trainable nn.Parameter, while later you explain that A is computed from the input via A = f(x). In that case, the layer would be a purely functional layer and would not contain any trainable parameters (which is fine). If so, you could indeed add some logic to the forward method to switch between the different A computations.
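
Something along these lines would work; this is only a rough sketch (the init_mode attribute and the particular choice of g are just examples, not an established PyTorch pattern):

```python
import torch
import torch.nn as nn

class PlaceholderMatrixLayer(nn.Module):
    def __init__(self, dim, hidden=32):
        super().__init__()
        # f holds the trainable parameters and produces A from the input
        self.f = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim * dim),
        )
        self.init_mode = True  # controls which function builds A

    def g(self, x):
        # placeholder for the (non-trainable) initialization function g(x),
        # here simply an identity matrix broadcast over the batch
        d = x.size(1)
        return torch.eye(d, device=x.device).expand(x.size(0), d, d)

    def forward(self, x):  # x: (batch, dim)
        d = x.size(1)
        if self.init_mode:
            A = self.g(x)                 # A = g(x) at initialization
        else:
            A = self.f(x).view(-1, d, d)  # A = f(x) during training
        return torch.matmul(A, x.unsqueeze(-1)).squeeze(-1)  # y = A x
```

You would then flip the switch once the initialization phase is done (e.g. layer.init_mode = False). As long as the actual training passes use A = f(x), the output stays differentiable with respect to f's parameters.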

I agree, sorry for the confusion. I edited the question for better clarity.
The matrix itself is not trainable; the trainable parameters are those inside the function f. The matrix A is just a placeholder for the output of f: during training, A is filled with f(x), but at initialization, A is based on g(x).
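
So concretely, using the sketch from your reply just for illustration, the behaviour I'm after is something like this:

```python
layer = PlaceholderMatrixLayer(dim=4)
x = torch.randn(8, 4)

# initialization phase: A is built from g(x)
layer.init_mode = True
y_init = layer(x)

# training phase: A is built from f(x), so gradients flow into f's parameters
layer.init_mode = False
y_train = layer(x)
y_train.sum().backward()
```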